import * as posedetection from "@tensorflow-models/pose-detection";
import * as tf from "@tensorflow/tfjs-core";
import "@tensorflow/tfjs-backend-webgl";

// Cached BlazePose detector. Assigned by _initDetectors and reset to null by
// inferKeypoints when inference with it fails, so it is rebuilt on next use.
let _blazePose;
// Cached MoveNet detector; same lifecycle as _blazePose.
let _moveNet;

/**
 * Creates whichever pose detectors are not currently loaded and stores them
 * in the module-level `_blazePose` / `_moveNet` variables.
 *
 * Only missing detectors are created: a detector that is still alive is left
 * untouched so it is neither leaked (overwritten without dispose) nor
 * needlessly reloaded when just the other one has to be rebuilt.
 *
 * @returns {Promise<void>} Resolves once both detectors are available.
 *   Rejects with whatever `posedetection.createDetector` rejects with.
 */
const _initDetectors = async () => {
  if (!_blazePose) {
    _blazePose = await posedetection.createDetector(
      posedetection.SupportedModels.BlazePose,
      {
        enableSmoothing: false,
        runtime: "tfjs",
        modelType: "heavy",
      }
    );
  }

  if (!_moveNet) {
    _moveNet = await posedetection.createDetector(
      posedetection.SupportedModels.MoveNet,
      {
        modelType: posedetection.movenet.modelType.SINGLEPOSE_THUNDER,
        enableSmoothing: false,
        minPoseScore: 0.25,
        enableTracking: false,
      }
    );
  }
};

/**
 * Runs single-pose estimation with `model` on `img` and returns the detected
 * keypoints indexed by keypoint name.
 *
 * If the first attempt detects nothing, estimation is retried once. Any
 * failure, on either attempt, disposes the model and is reported through
 * `onError` instead of being thrown.
 *
 * @param {object} model - Detector exposing `estimatePoses(img, config)` and
 *   `dispose()`.
 * @param {*} img - Input forwarded unchanged to `model.estimatePoses`.
 * @param {(error: unknown) => void} onError - Called with the error after the
 *   model has been disposed.
 * @returns {Promise<Object<string, object>|undefined>} Map from keypoint name
 *   to keypoint object for the first detected pose, or `undefined` when
 *   nothing was detected or estimation failed.
 */
const _doInference = async (model, img, onError) => {
  const estimationConfig = { maxPoses: 1, flipHorizontal: false };

  let poses;
  try {
    poses = await model.estimatePoses(img, estimationConfig);

    // If it doesn't detect anything, try again! This is weirdly effective...
    // The retry lives inside the try block so that a failure here is handled
    // exactly like a failure of the first attempt.
    if (!poses || poses.length === 0) {
      poses = await model.estimatePoses(img, estimationConfig);
    }
  } catch (error) {
    model.dispose();
    onError(error);
    return undefined;
  }

  if (!poses || poses.length === 0) {
    return undefined;
  }

  const keypoints = {};
  for (const kpt of poses[0].keypoints) {
    keypoints[kpt.name] = kpt;
  }
  return keypoints;
};

/**
 * Crops `img` to `bbox` and returns the crop as an int32 tensor.
 *
 * All intermediate tensors are created inside a single `tf.tidy` scope, so
 * only the returned tensor stays allocated. The caller owns that tensor and
 * must dispose it when done.
 *
 * @param {object} img - Pixel source accepted by `tf.browser.fromPixels`; its
 *   `width` and `height` properties are used to size the crop.
 * @param {{top: number, left: number, width: number, height: number}} bbox -
 *   Crop region expressed as fractions of the image size (the full image is
 *   `{ top: 0, left: 0, width: 1, height: 1 }`).
 * @returns {object} Cropped image tensor with the batch dimension removed.
 */
const _cropImage = (img, bbox) => {
  const { top, left, width, height } = bbox;

  return tf.tidy(() => {
    // cropAndResize works on batched input, so add a leading batch axis.
    const imageTensor4D = tf.expandDims(tf.browser.fromPixels(img), 0);

    // Crop region is a [batch, 4] size tensor: [y1, x1, y2, x2].
    const cropRegionTensor = tf.tensor2d([
      [top, left, top + height, left + width],
    ]);
    // The batch index that the crop should operate on. A [batch] size
    // tensor.
    const boxInd = tf.zeros([1], "int32");

    // Target size of each crop, in pixels: [cropHeight, cropWidth].
    const cropSize = [
      Math.round(height * img.height),
      Math.round(width * img.width),
    ];

    const cropped = tf.image.cropAndResize(
      imageTensor4D,
      cropRegionTensor,
      boxInd,
      cropSize,
      "bilinear",
      0
    );

    // Drop only the batch axis; an axis-less squeeze would also remove a
    // spatial dimension when the crop is a single pixel high or wide.
    return tf.squeeze(tf.cast(cropped, "int32"), [0]);
  });
};

/**
 * Detects body keypoints in `img`, optionally restricted to a sub-region.
 *
 * The image is cropped to `bbox`, both detectors are run on the crop, and the
 * results are merged: heel and foot-index keypoints come from BlazePose, all
 * other keypoints from MoveNet. Returned coordinates are shifted back by the
 * crop offset so they refer to the original, uncropped image.
 *
 * @param {object} img - Image source; its `width` and `height` properties are
 *   used to compute the crop offset.
 * @param {{top: number, left: number, width: number, height: number}|null}
 *   [bbox=null] - Region of interest as fractions of the image size. `null`
 *   selects the whole image.
 * @returns {Promise<Object<string, object>>} Map from keypoint name to
 *   keypoint. Empty when neither detector produced a result.
 */
export const inferKeypoints = async (img, bbox = null) => {
  const region =
    bbox === null ? { top: 0, left: 0, width: 1, height: 1 } : bbox;

  if (!_blazePose || !_moveNet) {
    await _initDetectors();
  }

  const croppedImage = _cropImage(img, region);

  let blazePoseResults;
  let moveNetResults;
  try {
    // do inference with cropped image
    blazePoseResults = await _doInference(_blazePose, croppedImage, () => {
      // Drop the failed detector so it gets rebuilt on the next call.
      _blazePose = null;
    });
    moveNetResults = await _doInference(_moveNet, croppedImage, () => {
      _moveNet = null;
    });
  } finally {
    // The crop is only needed for inference; release it even on failure so
    // repeated calls do not leak tensor memory.
    croppedImage.dispose();
  }

  // Keypoints taken from BlazePose rather than MoveNet.
  const blazePoseKeypoints = [
    "left_heel",
    "right_heel",
    "left_foot_index",
    "right_foot_index",
  ];

  const xOffset = region.left * img.width;
  const yOffset = region.top * img.height;
  // Translates a keypoint from crop coordinates to full-image coordinates.
  const toImageSpace = (kpt) => ({
    ...kpt,
    x: xOffset + kpt.x,
    y: yOffset + kpt.y,
  });

  const result = {};

  if (blazePoseResults) {
    blazePoseKeypoints.forEach((label) => {
      const kpt = blazePoseResults[label];
      // Skip keypoints the detector did not report instead of throwing.
      if (kpt) {
        result[label] = toImageSpace(kpt);
      }
    });
  }

  if (moveNetResults) {
    Object.keys(moveNetResults).forEach((label) => {
      result[label] = toImageSpace(moveNetResults[label]);
    });
  }

  return result;
};
