import { distance2D } from "./Math";

/**
 * Decorates a joint -> points mapping with time-based lookup helpers.
 *
 * Every non-function property of `jTP` is treated as a joint whose value is
 * an array of points. From each point this code reads `timestamp`,
 * `frame_idx` and `position` (`{x, y}`). Timestamps are multiplied by 1000
 * when cached (presumably seconds -> milliseconds; confirm against the
 * producer), and every `t` argument below is compared against those scaled
 * values. Each joint's points are assumed to be sorted by timestamp already.
 *
 * NOTE: the helpers are attached to `jTP` itself (the argument is mutated)
 * and the same object is returned.
 *
 * @param {Object} jTP - mapping of joint name to array of points.
 * @returns {Object} `jTP`, extended with `getJointNames`, `positionAt`,
 *   `addCorrection` and `estimateFps`.
 */
export function JointToPoints(jTP) {
  const MS_IN_SEC = 1000;
  const ascending = (a, b) => a - b;

  // Previously attached helper functions are not joints, so skip them.
  const jointNames = Object.entries(jTP)
    .filter(([, points]) => typeof points !== "function")
    .map(([joint]) => joint);

  // A Map (rather than a plain object) so that lookups for unknown names --
  // including inherited ones such as "constructor" -- simply yield undefined.
  const timestampsByJoint = new Map(
    jointNames.map((joint) => [
      joint,
      jTP[joint].map((p) => p.timestamp * MS_IN_SEC),
    ])
  );

  /**
   * Estimates frames per second from the joint that has the most points.
   *
   * Uses the 10th percentile of the positive gaps between consecutive
   * timestamps, so a handful of large gaps does not drag the estimate down.
   *
   * @returns {number} estimated fps, or NaN when there is no joint or no
   *   positive gap between timestamps to measure.
   */
  const estimateFps = () => {
    // Joint with the most points; the first one wins ties.
    let densestJoint = null;
    for (const joint of jointNames) {
      if (
        densestJoint === null ||
        jTP[joint].length > jTP[densestJoint].length
      ) {
        densestJoint = joint;
      }
    }
    if (densestJoint === null) {
      return NaN;
    }

    const times = [...timestampsByJoint.get(densestJoint)].sort(ascending);
    // The gaps must be sorted by size before a percentile can be read off.
    const deltas = times
      .slice(1)
      .map((t2, i) => t2 - times[i])
      .filter((delta) => delta > 0)
      .sort(ascending);
    if (deltas.length === 0) {
      return NaN;
    }
    const p10 = deltas[Math.floor(deltas.length / 10)];
    return MS_IN_SEC / p10;
  };

  /**
   * Binary search for the last point of `joint` whose timestamp is <= `t`.
   *
   * @param {string} joint - joint name.
   * @param {number} t - time, in the same scaled units as the cached
   *   timestamps.
   * @returns {number} index into the joint's points, or -1 when the joint is
   *   unknown, has no points, or `t` precedes its first timestamp.
   */
  const _indexAtOrBefore = (joint, t) => {
    // Note: we assume vals is pre-sorted
    const vals = timestampsByJoint.get(joint);
    if (vals === undefined || vals.length === 0) {
      return -1;
    }
    let lo = 0;
    let hi = vals.length - 1;
    if (t < vals[lo]) {
      return -1;
    }
    if (t >= vals[hi]) {
      return hi;
    }

    // Invariants:
    // * vals[lo] <= t
    // * (hi - lo) decreases on every iteration
    while (lo < hi) {
      // Round up so that `lo = mid` always makes progress.
      const mid = lo + Math.ceil((hi - lo) / 2);
      if (vals[mid] > t) {
        hi = mid - 1;
      } else if (vals[mid] < t) {
        lo = mid;
      } else {
        return mid;
      }
    }
    return lo;
  };

  /**
   * Returns the position of `joint` at time `t`.
   *
   * When `frameIndex` is given and a point with that exact `frame_idx`
   * exists, its position is returned as-is. Otherwise the position is
   * linearly interpolated between the point at-or-before `t` and the next
   * one; at or after the last point, the last position is returned.
   *
   * @param {string} joint - joint name.
   * @param {number} t - time, in the same scaled units as the cached
   *   timestamps.
   * @param {?number} [frameIndex=null] - exact frame to prefer, if any.
   * @returns {?{x: number, y: number}} the position, or null when the joint
   *   is unknown or empty, or `t` precedes its first point.
   */
  const positionAt = (joint, t, frameIndex = null) => {
    const points = jTP[joint];
    // Array check (not `in` + `.length`): functions also have a `length`, so
    // the attached helpers and inherited names would otherwise slip through.
    if (!Array.isArray(points) || points.length === 0) {
      return null;
    }

    // if we have this exact frame, use it
    if (frameIndex !== null) {
      const point = points.find(({ frame_idx }) => frame_idx === frameIndex);
      if (point) {
        return point["position"];
      }
    }

    // else, fall back to the timestamp and interpolate
    const idx = _indexAtOrBefore(joint, t);
    if (idx === -1) {
      return null;
    }
    const times = timestampsByJoint.get(joint);
    const t1 = times[idx];
    const p1 = points[idx]["position"];
    if (idx + 1 >= points.length) {
      return p1;
    }

    // interpolate between this point and the next one
    const t2 = times[idx + 1];
    // when j2p has duplicates it's possible that t2 === t1
    const weight = t2 === t1 ? 1 : (t - t1) / (t2 - t1);
    const p2 = points[idx + 1]["position"];
    return {
      x: (1 - weight) * p1.x + weight * p2.x,
      y: (1 - weight) * p1.y + weight * p2.y,
    };
  };

  /**
   * Mean 2D distance between the corrected positions and the original points
   * at index `i`. Yields Infinity if any corrected joint has no point at `i`.
   *
   * Currently unused; kept for the prev/next index selection described in the
   * TODO inside `addCorrection`.
   */
  const getAvgDistance = (i, fixedJointToPoint, originalJTP) => {
    const average = (xs) => xs.reduce((a, b) => a + b, 0) / xs.length;
    const distances = Object.entries(fixedJointToPoint).map(
      ([joint, { x, y }]) => {
        if (originalJTP[joint].length > i) {
          return distance2D({ x, y }, originalJTP[joint][i]["position"]);
        }
        return Infinity;
      }
    );
    return average(distances);
  };

  // TODO: use w, h in distance calculations
  /**
   * Returns a new JointToPoints in which the point at-or-before `t` of each
   * corrected joint has its position replaced. The receiver's point arrays
   * are not modified.
   *
   * One shared index is used for every corrected joint: the largest
   * at-or-before index among them. Joints that have no point at that index
   * (unknown joint, `t` before the first point, or a shorter track) are left
   * untouched instead of being padded with incomplete points.
   *
   * @param {number} t - time, in the same scaled units as the cached
   *   timestamps.
   * @param {Object} fixedJointToPoint - joint name -> corrected `{x, y}`.
   * @returns {Object} a new decorated mapping with the corrections applied.
   */
  const addCorrection = (t, fixedJointToPoint) => {
    const correctedJoints = Object.keys(fixedJointToPoint);
    const prevIdx = Math.max(
      ...correctedJoints.map((joint) => _indexAtOrBefore(joint, t))
    );

    // TODO: associating this edit with the "next" index makes for a bad UX
    // The "next" index isn't visible on screen yet if the video is paused, so it seems to
    // disappear on mouse up.
    // const nextIdx = prevIdx + 1;
    // const avgDistancePrev = getAvgDistance(prevIdx, fixedJointToPoint, jTP);
    // const avgDistanceNext = getAvgDistance(nextIdx, fixedJointToPoint, jTP);

    // For now let's always pick the "previous" index
    const idxToReplace = prevIdx;

    const copy = { ...jTP };
    correctedJoints.forEach((joint) => {
      const points = jTP[joint];
      const hasPointAtIdx =
        Array.isArray(points) &&
        idxToReplace >= 0 &&
        idxToReplace < points.length;
      if (!hasPointAtIdx) {
        // Writing here would add a stray "-1" key or leave holes in the array.
        return;
      }
      const updated = [...points];
      updated[idxToReplace] = {
        ...points[idxToReplace],
        position: fixedJointToPoint[joint],
      };
      copy[joint] = updated;
    });
    return JointToPoints(copy);
  };

  jTP.getJointNames = () => jointNames;
  jTP.positionAt = positionAt;
  jTP.addCorrection = addCorrection;
  jTP.estimateFps = estimateFps;
  return jTP;
}
