mirror of https://github.com/vladmandic/human
85 lines
4.4 KiB
JavaScript
85 lines
4.4 KiB
JavaScript
const keypoints = require('./keypoints');
|
|
const vectors = require('./vectors');
|
|
|
|
// Each `keypoints.poseChain` entry is destructured as a [parent, child] pair of
// part names; both names are translated to numeric part ids via
// `keypoints.partIds`, so every edge can index the per-part buffer channels.
const parentChildrenTuples = keypoints.poseChain.map((link) => {
  const [parentName, childName] = link;
  return [keypoints.partIds[parentName], keypoints.partIds[childName]];
});

// NOTE: the names describe the walking direction, not the stored value:
// `parentToChildEdges[e]` holds the *child* id of edge `e` (the target when
// walking parent -> child) and `childToParentEdges[e]` holds the *parent* id
// (the target when walking child -> parent).
const parentToChildEdges = parentChildrenTuples.map((tuple) => tuple[1]);
const childToParentEdges = parentChildrenTuples.map((tuple) => tuple[0]);
|
|
/**
 * Reads the 2-D displacement stored for one edge at one grid cell.
 *
 * The channel axis of `displacements` is split in half: the first half holds
 * the y-components of every edge and the second half the x-components, so the
 * x-component of `edgeId` lives at channel `shape[2] / 2 + edgeId`.
 *
 * @param {number} edgeId - index of the edge in the pose chain
 * @param {{x: number, y: number}} point - grid cell (row `y`, column `x`)
 * @param displacements - buffer exposing `shape` and `get(y, x, channel)`
 * @returns {{y: number, x: number}} displacement vector for that edge
 */
function getDisplacement(edgeId, point, displacements) {
  const half = displacements.shape[2] / 2;
  const row = point.y;
  const col = point.x;
  const dy = displacements.get(row, col, edgeId);
  const dx = displacements.get(row, col, half + edgeId);
  return { y: dy, x: dx };
}
|
|
/**
 * Snaps an image-space point to the nearest cell of the strided output grid.
 *
 * Coordinates are divided by `outputStride`, rounded with `Math.round`, and
 * clamped into `[0, height - 1]` / `[0, width - 1]` so the result is always a
 * valid index into a grid of that size.
 *
 * @param {{x: number, y: number}} point - position in image coordinates
 * @param {number} outputStride - scale factor between image coordinates and grid cells
 * @param {number} height - number of grid rows
 * @param {number} width - number of grid columns
 * @returns {{y: number, x: number}} grid row (`y`) and column (`x`)
 */
function getStridedIndexNearPoint(point, outputStride, height, width) {
  const row = Math.round(point.y / outputStride);
  const clampedRow = vectors.clamp(row, 0, height - 1);
  const col = Math.round(point.x / outputStride);
  const clampedCol = vectors.clamp(col, 0, width - 1);
  return { y: clampedRow, x: clampedCol };
}
|
|
/**
 * Decodes one keypoint of a pose instance by walking a single edge of the pose
 * tree, starting from a keypoint whose position is already known.
 *
 * The source position is snapped to the nearest grid cell, and the displacement
 * stored for `edgeId` at that cell (in the `edgeId`-th y/x channel pair of
 * `displacements`) is added to it. That estimate is then refined
 * `offsetRefineStep` times: each pass snaps the estimate to the grid again and
 * replaces it with that cell's image-space origin plus the target part's
 * offset vector. The score is read from `scoresBuffer` at the cell nearest the
 * final position.
 *
 * @param {number} edgeId - edge index; selects the displacement channels
 * @param sourceKeypoint - already-decoded keypoint with a `position` {x, y}
 * @param {number} targetKeypointId - part id of the keypoint being decoded
 * @param scoresBuffer - buffer whose `shape` starts with [height, width] and
 *   which exposes `get(y, x, partId)`
 * @param offsets - offset buffer, passed through to `vectors.getOffsetPoint`
 * @param {number} outputStride - scale factor between image coordinates and grid cells
 * @param displacements - forward or backward displacement buffer for this walk
 * @param {number} [offsetRefineStep=2] - number of offset refinement passes
 * @returns {{position: {x: number, y: number}, part: *, score: *}} the decoded
 *   keypoint; `part` comes from `keypoints.partNames[targetKeypointId]`
 */
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
  const [gridHeight, gridWidth] = scoresBuffer.shape;
  const snap = (pt) => getStridedIndexNearPoint(pt, outputStride, gridHeight, gridWidth);

  // Nearest-neighbour lookup of the source -> target displacement.
  const origin = sourceKeypoint.position;
  const step = getDisplacement(edgeId, snap(origin), displacements);
  let estimate = vectors.addVectors(origin, step);

  // Re-anchor the estimate on the grid and apply the target part's offset.
  for (let pass = 0; pass < offsetRefineStep; pass += 1) {
    const cell = snap(estimate);
    const offset = vectors.getOffsetPoint(cell.y, cell.x, targetKeypointId, offsets);
    const cellOrigin = { x: cell.x * outputStride, y: cell.y * outputStride };
    estimate = vectors.addVectors(cellOrigin, { x: offset.x, y: offset.y });
  }

  const finalCell = snap(estimate);
  const score = scoresBuffer.get(finalCell.y, finalCell.x, targetKeypointId);
  return {
    position: estimate,
    part: keypoints.partNames[targetKeypointId],
    score,
  };
}
|
|
/**
 * Decodes every keypoint of a single pose instance by following the
 * displacement fields outward from a root part.
 *
 * The root keypoint is placed first. The edges are then walked in reverse
 * order with the backward displacements (child -> parent, i.e. up the tree),
 * and afterwards in forward order with the forward displacements
 * (parent -> child, down the tree). An edge is only followed when its source
 * keypoint is already decoded and its target is not, so each part is filled
 * at most once.
 *
 * @param root - `{ part, score }`; `part.id` is the root's part id and `part`
 *   is passed to `vectors.getImageCoords`
 * @param scores - score buffer; `shape[2]` is the number of parts
 * @param offsets - offset buffer
 * @param {number} outputStride - scale factor between image coordinates and grid cells
 * @param displacementsFwd - displacements used for parent -> child edges
 * @param displacementsBwd - displacements used for child -> parent edges
 * @returns {Array} array of length `scores.shape[2]` indexed by part id; each
 *   filled slot is `{ score, part, position }`, and parts never reached are
 *   left as holes
 */
function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
  const partCount = scores.shape[2];
  const edgeCount = parentToChildEdges.length;
  const decoded = new Array(partCount);

  // Seed the instance with the root part.
  const rootPart = root.part;
  const rootScore = root.score;
  const rootPoint = vectors.getImageCoords(rootPart, outputStride, offsets);
  decoded[rootPart.id] = {
    score: rootScore,
    part: keypoints.partNames[rootPart.id],
    position: rootPoint,
  };

  // Follows one edge from `fromId` to `toId`, but only when `fromId` is already
  // known and `toId` is still missing.
  const follow = (edge, fromId, toId, displacements) => {
    if (!decoded[fromId] || decoded[toId]) return;
    decoded[toId] = traverseToTargetKeypoint(edge, decoded[fromId], toId, scores, offsets, outputStride, displacements);
  };

  // Upwards: child -> parent along the backward displacements, last edge first.
  for (let edge = edgeCount - 1; edge >= 0; edge -= 1) {
    follow(edge, parentToChildEdges[edge], childToParentEdges[edge], displacementsBwd);
  }

  // Downwards: parent -> child along the forward displacements, first edge first.
  for (let edge = 0; edge < edgeCount; edge += 1) {
    follow(edge, childToParentEdges[edge], parentToChildEdges[edge], displacementsFwd);
  }

  return decoded;
}
|
|
// Expose the single-pose decoder to CommonJS consumers of this module.
Object.assign(exports, { decodePose });
|