// Copyright 2013 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'dart:convert';
import 'dart:ffi';
import 'dart:typed_data';

import 'package:ui/src/engine.dart';
import 'package:ui/src/engine/skwasm/skwasm_impl.dart';
import 'package:ui/ui.dart' as ui;

/// A shared interface for shaders for which you can acquire a native handle.
abstract class SkwasmShader implements ui.Shader {
  /// The native handle of the underlying shader object.
  ShaderHandle get handle;

  /// Whether the shader represents a gradient.
  bool get isGradient;
}

/// An implementation that handles the storage, disposal, and finalization of
/// a native shader handle.
class SkwasmNativeShader extends SkwasmObjectWrapper<RawShader> implements SkwasmShader {
  SkwasmNativeShader(ShaderHandle handle) : super(handle, _registry);

  /// Registry whose callback hands a shader handle to [shaderDispose].
  static final SkwasmFinalizationRegistry<RawShader> _registry =
      SkwasmFinalizationRegistry<RawShader>((ShaderHandle shader) {
        shaderDispose(shader);
      });

  /// Always `false` here; gradient subclasses override this.
  @override
  bool get isGradient {
    return false;
  }
}

/// A [ui.Gradient] backed by a native gradient shader handle.
///
/// Each factory copies its arguments into native memory inside a
/// [withStackScope] callback, creates the native shader, and wraps the
/// resulting handle.
class SkwasmGradient extends SkwasmNativeShader implements ui.Gradient {
  /// Creates a linear gradient running from [from] to [to].
  ///
  /// [colorStops] and [matrix4] are passed to the native side as `nullptr`
  /// when they are null.
  factory SkwasmGradient.linear({
    required ui.Offset from,
    required ui.Offset to,
    required List<ui.Color> colors,
    List<double>? colorStops,
    ui.TileMode tileMode = ui.TileMode.clamp,
    Float32List? matrix4,
  }) {
    return withStackScope((StackScope scope) {
      // Validation only runs when assertions are enabled.
      assert(() {
        validateColorStops(colors, colorStops);
        return true;
      }());

      final RawPointArray nativePoints = scope.convertPointArrayToNative(<ui.Offset>[from, to]);
      final RawColorArray nativeColors = scope.convertColorArrayToNative(colors);
      final Pointer<Float> nativeStops = colorStops == null
          ? nullptr
          : scope.convertDoublesToNative(colorStops);
      final Pointer<Float> nativeMatrix = matrix4 == null
          ? nullptr
          : scope.convertMatrix4toSkMatrix(matrix4);
      return SkwasmGradient._(
        shaderCreateLinearGradient(
          nativePoints,
          nativeColors,
          nativeStops,
          colors.length,
          tileMode.index,
          nativeMatrix,
        ),
      );
    });
  }

  /// Creates a radial gradient centered at [center] with the given [radius].
  ///
  /// [colorStops] and [matrix4] are passed to the native side as `nullptr`
  /// when they are null.
  factory SkwasmGradient.radial({
    required ui.Offset center,
    required double radius,
    required List<ui.Color> colors,
    List<double>? colorStops,
    ui.TileMode tileMode = ui.TileMode.clamp,
    Float32List? matrix4,
  }) {
    return withStackScope((StackScope scope) {
      // Validation only runs when assertions are enabled.
      assert(() {
        validateColorStops(colors, colorStops);
        return true;
      }());

      final RawColorArray nativeColors = scope.convertColorArrayToNative(colors);
      final Pointer<Float> nativeStops = colorStops == null
          ? nullptr
          : scope.convertDoublesToNative(colorStops);
      final Pointer<Float> nativeMatrix = matrix4 == null
          ? nullptr
          : scope.convertMatrix4toSkMatrix(matrix4);
      return SkwasmGradient._(
        shaderCreateRadialGradient(
          center.dx,
          center.dy,
          radius,
          nativeColors,
          nativeStops,
          colors.length,
          tileMode.index,
          nativeMatrix,
        ),
      );
    });
  }

  /// Creates a two-point conical gradient between the circle at [focal] with
  /// [focalRadius] and the circle at [center] with [centerRadius].
  ///
  /// [colorStops] and [matrix4] are passed to the native side as `nullptr`
  /// when they are null.
  factory SkwasmGradient.conical({
    required ui.Offset focal,
    required double focalRadius,
    required ui.Offset center,
    required double centerRadius,
    required List<ui.Color> colors,
    List<double>? colorStops,
    ui.TileMode tileMode = ui.TileMode.clamp,
    Float32List? matrix4,
  }) {
    return withStackScope((StackScope scope) {
      // Validation only runs when assertions are enabled.
      assert(() {
        validateColorStops(colors, colorStops);
        return true;
      }());

      final RawPointArray nativePoints = scope.convertPointArrayToNative(<ui.Offset>[
        focal,
        center,
      ]);
      final RawColorArray nativeColors = scope.convertColorArrayToNative(colors);
      final Pointer<Float> nativeStops = colorStops == null
          ? nullptr
          : scope.convertDoublesToNative(colorStops);
      final Pointer<Float> nativeMatrix = matrix4 == null
          ? nullptr
          : scope.convertMatrix4toSkMatrix(matrix4);
      return SkwasmGradient._(
        shaderCreateConicalGradient(
          nativePoints,
          focalRadius,
          centerRadius,
          nativeColors,
          nativeStops,
          colors.length,
          tileMode.index,
          nativeMatrix,
        ),
      );
    });
  }

  /// Creates a sweep gradient around [center].
  ///
  /// [startAngle] and [endAngle] are passed through [ui.toDegrees] before
  /// being handed to the native side. [colorStops] and [matrix4] are passed
  /// as `nullptr` when they are null.
  factory SkwasmGradient.sweep({
    required ui.Offset center,
    required List<ui.Color> colors,
    List<double>? colorStops,
    ui.TileMode tileMode = ui.TileMode.clamp,
    required double startAngle,
    required double endAngle,
    Float32List? matrix4,
  }) {
    return withStackScope((StackScope scope) {
      // Validation only runs when assertions are enabled.
      assert(() {
        validateColorStops(colors, colorStops);
        return true;
      }());

      final RawColorArray nativeColors = scope.convertColorArrayToNative(colors);
      final Pointer<Float> nativeStops = colorStops == null
          ? nullptr
          : scope.convertDoublesToNative(colorStops);
      final Pointer<Float> nativeMatrix = matrix4 == null
          ? nullptr
          : scope.convertMatrix4toSkMatrix(matrix4);
      return SkwasmGradient._(
        shaderCreateSweepGradient(
          center.dx,
          center.dy,
          nativeColors,
          nativeStops,
          colors.length,
          tileMode.index,
          ui.toDegrees(startAngle),
          ui.toDegrees(endAngle),
          nativeMatrix,
        ),
      );
    });
  }

  /// Wraps an already-created native gradient shader handle.
  SkwasmGradient._(super.handle);

  @override
  bool get isGradient {
    return true;
  }

  @override
  String toString() {
    return 'Gradient()';
  }
}

/// A [ui.ImageShader] backed by a native shader created from an image.
class SkwasmImageShader extends SkwasmNativeShader implements ui.ImageShader {
  SkwasmImageShader._(super.handle);

  /// Creates a shader that samples [image] with tile modes [tmx] and [tmy].
  ///
  /// A null [filterQuality] is treated as [ui.FilterQuality.none]. When
  /// [matrix4] is null, `nullptr` is passed as the local matrix; otherwise the
  /// matrix is converted inside a stack scope before the native call.
  factory SkwasmImageShader.imageShader(
    SkwasmImage image,
    ui.TileMode tmx,
    ui.TileMode tmy,
    Float64List? matrix4,
    ui.FilterQuality? filterQuality,
  ) {
    final int qualityIndex = (filterQuality ?? ui.FilterQuality.none).index;

    if (matrix4 == null) {
      final ShaderHandle created = shaderCreateFromImage(
        image.handle,
        tmx.index,
        tmy.index,
        qualityIndex,
        nullptr,
      );
      return SkwasmImageShader._(created);
    }

    return withStackScope((StackScope scope) {
      final RawMatrix33 nativeMatrix = scope.convertMatrix4toSkMatrix(matrix4);
      final ShaderHandle created = shaderCreateFromImage(
        image.handle,
        tmx.index,
        tmy.index,
        qualityIndex,
        nativeMatrix,
      );
      return SkwasmImageShader._(created);
    });
  }
}

/// Wraps a native runtime effect handle together with the [ShaderData] it was
/// created from.
class SkwasmFragmentProgram extends SkwasmObjectWrapper<RawRuntimeEffect>
    implements ui.FragmentProgram {
  SkwasmFragmentProgram._(this.name, RuntimeEffectHandle handle, this._shaderData)
    : super(handle, _registry);

  /// Parses [bytes] into [ShaderData] and creates a runtime effect from its
  /// source text.
  ///
  /// The source is UTF-8 encoded and copied byte by byte into a temporary
  /// native string, which is freed once the runtime effect has been created.
  factory SkwasmFragmentProgram.fromBytes(String name, Uint8List bytes) {
    final parsed = ShaderData.fromBytes(bytes);

    // TODO(jacksongardner): Can we avoid this copy?
    final List<int> encodedSource = utf8.encode(parsed.source);
    final SkStringHandle nativeSource = skStringAllocate(encodedSource.length);
    final Pointer<Int8> destination = skStringGetData(nativeSource);
    for (var offset = 0; offset < encodedSource.length; offset++) {
      destination[offset] = encodedSource[offset];
    }
    final RuntimeEffectHandle effect = runtimeEffectCreate(nativeSource);
    skStringFree(nativeSource);
    return SkwasmFragmentProgram._(name, effect, parsed);
  }

  /// Registry whose callback hands a runtime effect handle to
  /// [runtimeEffectDispose].
  static final SkwasmFinalizationRegistry<RawRuntimeEffect> _registry =
      SkwasmFinalizationRegistry<RawRuntimeEffect>((RuntimeEffectHandle effect) {
        runtimeEffectDispose(effect);
      });

  /// The name this program was created with.
  final String name;

  /// Metadata parsed from the bytes passed to [SkwasmFragmentProgram.fromBytes].
  final ShaderData _shaderData;

  /// The float count reported by the parsed shader data.
  int get floatUniformCount {
    return _shaderData.floatCount;
  }

  /// The texture count reported by the parsed shader data.
  int get childShaderCount {
    return _shaderData.textureCount;
  }

  /// The uniform size reported by the native runtime effect.
  int get uniformSize {
    return runtimeEffectGetUniformSize(handle);
  }

  @override
  ui.FragmentShader fragmentShader() {
    return SkwasmFragmentShader(this);
  }

  /// Returns the uniform metadata whose name equals [name].
  ///
  /// Throws an [ArgumentError] if no such uniform exists.
  UniformData _getUniformFloatInfo(String name) {
    for (final UniformData candidate in _shaderData.uniforms) {
      if (candidate.name != name) {
        continue;
      }
      return candidate;
    }
    throw ArgumentError('No uniform named "$name".');
  }
}

/// Owns a native uniform data object created with the requested size.
class SkwasmShaderData extends SkwasmObjectWrapper<RawUniformData> {
  SkwasmShaderData(int size) : super(uniformDataCreate(size), _registry);

  /// Registry whose callback hands a uniform data handle to
  /// [uniformDataDispose].
  static final SkwasmFinalizationRegistry<RawUniformData> _registry =
      SkwasmFinalizationRegistry<RawUniformData>((UniformDataHandle data) {
        uniformDataDispose(data);
      });

  /// The pointer returned by [uniformDataGetPointer] for this object's handle.
  Pointer<Void> get pointer {
    return uniformDataGetPointer(handle);
  }
}

/// A [ui.FragmentShader] whose native shader is built lazily from a
/// [SkwasmFragmentProgram], its uniform data, and its child shaders.
///
/// This class does not inherit from [SkwasmNativeShader], as its handle might
/// change over time if the uniforms or image shaders are changed. Instead this
/// wraps a [SkwasmNativeShader] that it creates and destroys on demand. It does
/// implement [SkwasmShader] though, in order to provide the handle for the
/// underlying shader object.
class SkwasmFragmentShader implements SkwasmShader, ui.FragmentShader {
  /// Creates a shader for [program] with freshly allocated uniform data and
  /// one empty child-shader slot per child shader the program declares.
  SkwasmFragmentShader(SkwasmFragmentProgram program)
    : _program = program,
      _uniformData = SkwasmShaderData(program.uniformSize),
      _floatUniformCount = program.floatUniformCount,
      _childShaders = List<SkwasmShader?>.filled(program.childShaderCount, null);

  /// The handle of the native shader, creating the native shader first if no
  /// cached one exists.
  ///
  /// Unset child-shader slots are passed to the native side as `nullptr`.
  @override
  ShaderHandle get handle {
    final SkwasmShader? cached = _nativeShader;
    if (cached != null) {
      return cached.handle;
    }

    final ShaderHandle newHandle = withStackScope((StackScope s) {
      Pointer<ShaderHandle> childShaders = nullptr;
      if (_childShaders.isNotEmpty) {
        childShaders = s.allocPointerArray(_childShaders.length).cast<ShaderHandle>();
        for (var i = 0; i < _childShaders.length; i++) {
          final SkwasmShader? child = _childShaders[i];
          childShaders[i] = child != null ? child.handle : nullptr;
        }
      }
      return shaderCreateRuntimeEffectShader(
        _program.handle,
        _uniformData.handle,
        childShaders,
        _childShaders.length,
      );
    });
    final created = SkwasmNativeShader(newHandle);
    _nativeShader = created;
    return created.handle;
  }

  @override
  bool get isGradient => false;

  /// The cached native shader, or null if it has not been created yet or was
  /// invalidated by a uniform or sampler change.
  SkwasmShader? _nativeShader;
  final SkwasmFragmentProgram _program;
  final SkwasmShaderData _uniformData;
  bool _isDisposed = false;
  final int _floatUniformCount;

  /// Child shaders indexed by sampler slot; entries are only ever populated
  /// by [setImageSampler].
  final List<SkwasmShader?> _childShaders;

  /// Disposes the cached native shader, if any, so that the next read of
  /// [handle] recreates it from the current uniforms and child shaders.
  void _invalidateNativeShader() {
    _nativeShader?.dispose();
    _nativeShader = null;
  }

  /// Releases the native shader, the child image shaders, and the uniform
  /// data owned by this object.
  @override
  void dispose() {
    assert(!_isDisposed);
    _nativeShader?.dispose();
    // The child shaders are created inside [setImageSampler] and are not
    // exposed anywhere else, so this object is responsible for disposing
    // them. Clearing each slot prevents a second dispose of the same child.
    for (var i = 0; i < _childShaders.length; i++) {
      _childShaders[i]?.dispose();
      _childShaders[i] = null;
    }
    _uniformData.dispose();
    _isDisposed = true;
  }

  @override
  bool get debugDisposed => _isDisposed;

  /// Writes [value] into the float at [index] of the uniform data and
  /// invalidates the cached native shader.
  ///
  /// NOTE(review): [index] is not range-checked before the native write.
  @override
  void setFloat(int index, double value) {
    // Invalidate the previous shader so that it is recreated with the new
    // uniform data.
    _invalidateNativeShader();
    final Pointer<Float> dataPointer = _uniformData.pointer.cast<Float>();
    dataPointer[index] = value;
  }

  /// Binds [image] to the sampler slot [index] using a clamped image shader.
  ///
  /// Any shader previously stored in that slot is disposed. The image's width
  /// and height are also written into the uniform data as two floats located
  /// at `_floatUniformCount + index * 2`.
  @override
  void setImageSampler(
    int index,
    ui.Image image, {
    ui.FilterQuality filterQuality = ui.FilterQuality.none,
  }) {
    // Invalidate the previous shader so that it is recreated with the new
    // child shaders.
    _invalidateNativeShader();

    final shader = SkwasmImageShader.imageShader(
      image as SkwasmImage,
      ui.TileMode.clamp,
      ui.TileMode.clamp,
      null,
      filterQuality,
    );
    final SkwasmShader? oldShader = _childShaders[index];
    _childShaders[index] = shader;
    oldShader?.dispose();

    final Pointer<Float> dataPointer = _uniformData.pointer.cast<Float>();
    dataPointer[_floatUniformCount + index * 2] = image.width.toDouble();
    dataPointer[_floatUniformCount + index * 2 + 1] = image.height.toDouble();
  }

  /// Returns a slot for the float at [index] (default 0) within the uniform
  /// called [name].
  ///
  /// Throws an [ArgumentError] if the uniform does not exist and an
  /// [IndexError] if [index] is outside the uniform's float count.
  @override
  ui.UniformFloatSlot getUniformFloat(String name, [int? index]) {
    index ??= 0;
    final UniformData info = _program._getUniformFloatInfo(name);

    IndexError.check(index, info.floatCount, message: 'Index `$index` out of bounds for `$name`.');

    return SkwasmUniformFloatSlot._(this, index, name, info.floatOffset + index);
  }

  /// Returns a two-component slot for the uniform called [name].
  @override
  ui.UniformVec2Slot getUniformVec2(String name) {
    final List<SkwasmUniformFloatSlot> slots = _getUniformFloatSlots(name, 2);
    return _SkwasmUniformVec2Slot._(slots[0], slots[1]);
  }

  /// Returns a three-component slot for the uniform called [name].
  @override
  ui.UniformVec3Slot getUniformVec3(String name) {
    final List<SkwasmUniformFloatSlot> slots = _getUniformFloatSlots(name, 3);
    return _SkwasmUniformVec3Slot._(slots[0], slots[1], slots[2]);
  }

  /// Returns a four-component slot for the uniform called [name].
  @override
  ui.UniformVec4Slot getUniformVec4(String name) {
    final List<SkwasmUniformFloatSlot> slots = _getUniformFloatSlots(name, 4);
    return _SkwasmUniformVec4Slot._(slots[0], slots[1], slots[2], slots[3]);
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformMat2Slot getUniformMat2(String name) {
    throw UnsupportedError('getUniformMat2 is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformMat3Slot getUniformMat3(String name) {
    throw UnsupportedError('getUniformMat3 is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformMat4Slot getUniformMat4(String name) {
    throw UnsupportedError('getUniformMat4 is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformArray<ui.UniformMat2Slot> getUniformMat2Array(String name) {
    throw UnsupportedError('getUniformMat2Array is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformArray<ui.UniformMat3Slot> getUniformMat3Array(String name) {
    throw UnsupportedError('getUniformMat3Array is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.UniformArray<ui.UniformMat4Slot> getUniformMat4Array(String name) {
    throw UnsupportedError('getUniformMat4Array is not supported on the web.');
  }

  /// Not supported; always throws [UnsupportedError].
  @override
  ui.ImageSamplerSlot getImageSampler(String name) {
    throw UnsupportedError('getImageSampler is not supported on the web.');
  }

  /// Returns one float slot per component of the uniform called [name].
  ///
  /// Throws an [ArgumentError] if the uniform does not exist or if its float
  /// count is not exactly [size].
  List<SkwasmUniformFloatSlot> _getUniformFloatSlots(String name, int size) {
    final UniformData info = _program._getUniformFloatInfo(name);

    if (info.floatCount != size) {
      throw ArgumentError('Uniform `$name` has size ${info.floatCount}, not size $size.');
    }

    return List<SkwasmUniformFloatSlot>.generate(
      size,
      (i) => SkwasmUniformFloatSlot._(this, i, name, info.floatOffset + i),
    );
  }

  /// Builds an array view over the uniform called [name], grouping its floats
  /// into elements of [elementSize] floats each.
  ///
  /// [elementFactory] receives exactly [elementSize] float slots for one
  /// element and returns the typed slot for it. Throws an [ArgumentError] if
  /// the uniform does not exist or if its float count is not a multiple of
  /// [elementSize].
  ui.UniformArray<T> _getUniformArray<T extends ui.UniformType>(
    String name,
    int elementSize,
    T Function(List<SkwasmUniformFloatSlot> slots) elementFactory,
  ) {
    final UniformData info = _program._getUniformFloatInfo(name);

    if (info.floatCount % elementSize != 0) {
      throw ArgumentError(
        'Uniform size (${info.floatCount}) for "$name" is not a multiple of $elementSize.',
      );
    }
    final int numElements = info.floatCount ~/ elementSize;

    final elements = List<T>.generate(numElements, (i) {
      // Only the floats that belong to element `i` are needed. Generating
      // `info.floatCount` slots per element would allocate a quadratic number
      // of slots, most of them pointing past this element.
      final slots = List<SkwasmUniformFloatSlot>.generate(
        elementSize,
        (j) => SkwasmUniformFloatSlot._(this, j, name, info.floatOffset + i * elementSize + j),
      );
      return elementFactory(slots);
    });

    return _SkwasmUniformArray<T>._(elements);
  }

  /// Returns an array of single-float slots for the uniform called [name].
  @override
  ui.UniformArray<ui.UniformFloatSlot> getUniformFloatArray(String name) {
    return _getUniformArray(name, 1, (components) => components.first);
  }

  /// Returns an array of two-component slots for the uniform called [name].
  @override
  ui.UniformArray<ui.UniformVec2Slot> getUniformVec2Array(String name) {
    return _getUniformArray<_SkwasmUniformVec2Slot>(
      name,
      2, // 2 floats per element
      (components) => _SkwasmUniformVec2Slot._(
        components[0],
        components[1],
      ), // Create Vec2 from two UniformFloat components
    );
  }

  /// Returns an array of three-component slots for the uniform called [name].
  @override
  ui.UniformArray<ui.UniformVec3Slot> getUniformVec3Array(String name) {
    return _getUniformArray<_SkwasmUniformVec3Slot>(
      name,
      3, // 3 floats per element
      (components) =>
          _SkwasmUniformVec3Slot._(components[0], components[1], components[2]), // Create Vec3
    );
  }

  /// Returns an array of four-component slots for the uniform called [name].
  @override
  ui.UniformArray<ui.UniformVec4Slot> getUniformVec4Array(String name) {
    return _getUniformArray<_SkwasmUniformVec4Slot>(
      name,
      4, // 4 floats per element
      (components) => _SkwasmUniformVec4Slot._(
        components[0],
        components[1],
        components[2],
        components[3],
      ), // Create Vec4
    );
  }
}

/// A [ui.UniformFloatSlot] that forwards writes to a [SkwasmFragmentShader].
class SkwasmUniformFloatSlot implements ui.UniformFloatSlot {
  SkwasmUniformFloatSlot._(this._shader, this.index, this.name, this.shaderIndex);

  /// The shader whose uniform data this slot writes into.
  final SkwasmFragmentShader _shader;

  @override
  final String name;

  @override
  final int index;

  @override
  final int shaderIndex;

  /// Stores [val] at [shaderIndex] via [SkwasmFragmentShader.setFloat].
  @override
  void set(double val) => _shader.setFloat(shaderIndex, val);
}

/// A [ui.UniformVec2Slot] composed of two float slots.
class _SkwasmUniformVec2Slot implements ui.UniformVec2Slot {
  _SkwasmUniformVec2Slot._(this._xSlot, this._ySlot);

  final SkwasmUniformFloatSlot _xSlot;
  final SkwasmUniformFloatSlot _ySlot;

  /// Writes [x] and then [y] to their respective float slots.
  @override
  void set(double x, double y) {
    _xSlot.set(x);
    _ySlot.set(y);
  }
}

/// A [ui.UniformVec3Slot] composed of three float slots.
class _SkwasmUniformVec3Slot implements ui.UniformVec3Slot {
  _SkwasmUniformVec3Slot._(this._xSlot, this._ySlot, this._zSlot);

  final SkwasmUniformFloatSlot _xSlot;
  final SkwasmUniformFloatSlot _ySlot;
  final SkwasmUniformFloatSlot _zSlot;

  /// Writes [x], [y], and [z], in that order, to their respective float slots.
  @override
  void set(double x, double y, double z) {
    _xSlot.set(x);
    _ySlot.set(y);
    _zSlot.set(z);
  }
}

/// A [ui.UniformVec4Slot] composed of four float slots.
class _SkwasmUniformVec4Slot implements ui.UniformVec4Slot {
  _SkwasmUniformVec4Slot._(this._xSlot, this._ySlot, this._zSlot, this._wSlot);

  final SkwasmUniformFloatSlot _xSlot;
  final SkwasmUniformFloatSlot _ySlot;
  final SkwasmUniformFloatSlot _zSlot;
  final SkwasmUniformFloatSlot _wSlot;

  /// Writes [x], [y], [z], and [w], in that order, to their respective float
  /// slots.
  @override
  void set(double x, double y, double z, double w) {
    _xSlot.set(x);
    _ySlot.set(y);
    _zSlot.set(z);
    _wSlot.set(w);
  }
}

/// A [ui.UniformArray] backed by a list of already-constructed slots.
class _SkwasmUniformArray<T extends ui.UniformType> implements ui.UniformArray<T> {
  _SkwasmUniformArray._(this._elements);

  /// The slots exposed through [operator []] and [length].
  final List<T> _elements;

  @override
  int get length {
    return _elements.length;
  }

  /// Returns the slot at [index] of the backing list.
  @override
  T operator [](int index) => _elements[index];
}
