// Mirror of https://github.com/vladmandic/human (TypeScript source, 105 lines, 5.1 KiB)
import * as keypoints from './keypoints';
|
|
import * as decoders from './decodeParts';
|
|
import * as utils from './utils';
|
|
|
|
// One [parentId, childId] pair per entry of keypoints.poseChain, in poseChain order.
// Joint names are translated to numeric ids through keypoints.partIds.
const parentChildrenTuples = keypoints.poseChain.map(([parentJointName, childJointName]) => [
  keypoints.partIds[parentJointName],
  keypoints.partIds[childJointName],
]);

// For edge index e: the id of the child joint of that edge.
const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId);

// For edge index e: the id of the parent joint of that edge.
const childToParentEdges = parentChildrenTuples.map(([parentJointId]) => parentJointId);

// Output stride passed to decoders.getOffsetPoints by decodeSinglePose.
const defaultOutputStride = 16;
|
|
|
|
/**
 * Reads the displacement vector stored for one edge at one grid cell.
 *
 * The third dimension of `displacements` is treated as two halves of equal size:
 * channel `edgeId` supplies the y component and channel `numEdges + edgeId`
 * supplies the x component, where `numEdges` is half of `displacements.shape[2]`.
 *
 * @param {number} edgeId - Index of the edge whose displacement is read.
 * @param {{x: number, y: number}} point - Grid cell indices used for the lookup.
 * @param {object} displacements - Buffer-like object exposing `shape` and `get(y, x, channel)`.
 * @returns {{y: number, x: number}} The displacement read for that edge and cell.
 */
function getDisplacement(edgeId, point, displacements) {
  const numEdges = displacements.shape[2] / 2;
  const { y: row, x: column } = point;
  return {
    y: displacements.get(row, column, edgeId),
    x: displacements.get(row, column, numEdges + edgeId),
  };
}
|
|
|
|
/**
 * Converts a point into the indices of the nearest cell of a strided grid.
 *
 * Each coordinate is divided by `outputStride`, rounded with `Math.round`, and then
 * passed through `utils.clamp` with bounds `0` and `size - 1` for that axis.
 *
 * @param {{x: number, y: number}} point - Point to convert.
 * @param {number} outputStride - Divisor applied to both coordinates.
 * @param {number} height - Number of grid rows; bounds the y index.
 * @param {number} width - Number of grid columns; bounds the x index.
 * @returns {{y: number, x: number}} Grid indices for the point.
 */
function getStridedIndexNearPoint(point, outputStride, height, width) {
  const toGridIndex = (coordinate, size) => utils.clamp(Math.round(coordinate / outputStride), 0, size - 1);
  return {
    y: toGridIndex(point.y, height),
    x: toGridIndex(point.x, width),
  };
}
|
|
|
|
/**
 * Follows one edge from an already decoded keypoint to the keypoint on the other end.
 *
 * The displacement for `edgeId` is looked up at the grid cell nearest to the source
 * position and added to that position. The estimate is then refined
 * `offsetRefineStep` times: each pass snaps the estimate to its nearest grid cell and
 * replaces it with that cell's origin (`index * outputStride`) plus the offset that
 * `utils.getOffsetPoint` returns for the target keypoint at that cell.
 *
 * @param {number} edgeId - Edge index used for the displacement lookup.
 * @param {{position: {x: number, y: number}}} sourceKeypoint - Keypoint to start from.
 * @param {number} targetKeypointId - Id of the keypoint being located.
 * @param {object} scoresBuffer - Exposes `shape` (first two entries are read as
 *   height and width) and `get(y, x, keypointId)`.
 * @param {object} offsets - Passed through to `utils.getOffsetPoint`.
 * @param {number} outputStride - Scale between grid indices and position coordinates.
 * @param {object} displacements - Passed through to `getDisplacement`.
 * @param {number} [offsetRefineStep=2] - Number of refinement passes.
 * @returns {{position: {x: number, y: number}, part: string, score: number}}
 *   The located keypoint with the score read at its final grid cell.
 */
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
  const [height, width] = scoresBuffer.shape;
  const nearestIndices = (point) => getStridedIndexNearPoint(point, outputStride, height, width);

  // Nearest neighbor interpolation for the source->target displacement.
  const sourceIndices = nearestIndices(sourceKeypoint.position);
  const displacement = getDisplacement(edgeId, sourceIndices, displacements);
  let targetPosition = utils.addVectors(sourceKeypoint.position, displacement);

  for (let step = 0; step < offsetRefineStep; step++) {
    const cell = nearestIndices(targetPosition);
    const offsetPoint = utils.getOffsetPoint(cell.y, cell.x, targetKeypointId, offsets);
    targetPosition = utils.addVectors(
      { x: cell.x * outputStride, y: cell.y * outputStride },
      { x: offsetPoint.x, y: offsetPoint.y },
    );
  }

  const targetIndices = nearestIndices(targetPosition);
  const score = scoresBuffer.get(targetIndices.y, targetIndices.x, targetKeypointId);
  return { position: targetPosition, part: keypoints.partNames[targetKeypointId], score };
}
|
|
|
|
/**
 * Decodes all keypoints of one pose instance starting from a root part.
 *
 * The root keypoint is placed with `utils.getImageCoords`. The edges are then walked
 * twice: first from the last edge to the first using the backward displacements
 * (child -> parent), then from the first edge to the last using the forward
 * displacements (parent -> child). An edge is followed only when its source keypoint
 * is already decoded and its target keypoint is not.
 *
 * @param {{part: {id: number}, score: number}} root - Root part and its score.
 * @param {object} scores - Exposes `shape`; `shape[2]` is used as the number of parts.
 * @param {object} offsets - Passed through to the helpers that read offsets.
 * @param {number} outputStride - Passed through to the helpers.
 * @param {object} displacementsFwd - Displacements used for the parent -> child pass.
 * @param {object} displacementsBwd - Displacements used for the child -> parent pass.
 * @returns {Array<{score: number, part: string, position: object}>} Keypoints indexed
 *   by part id; entries that were never reached are left unset.
 */
export function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
  const numParts = scores.shape[2];
  const numEdges = parentToChildEdges.length;
  const instanceKeypoints = new Array(numParts);

  // Start a new detection instance at the position of the root.
  const { part: rootPart, score: rootScore } = root;
  instanceKeypoints[rootPart.id] = {
    score: rootScore,
    part: keypoints.partNames[rootPart.id],
    position: utils.getImageCoords(rootPart, outputStride, offsets),
  };

  // Follows one edge, but only from a decoded keypoint to one that is still missing.
  const decodeEdge = (edge, sourceKeypointId, targetKeypointId, displacements) => {
    const sourceKeypoint = instanceKeypoints[sourceKeypointId];
    if (sourceKeypoint && !instanceKeypoints[targetKeypointId]) {
      instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, sourceKeypoint, targetKeypointId, scores, offsets, outputStride, displacements);
    }
  };

  // Decode the part positions upwards in the tree, following the backward displacements.
  for (let edge = numEdges - 1; edge >= 0; --edge) {
    decodeEdge(edge, parentToChildEdges[edge], childToParentEdges[edge], displacementsBwd);
  }

  // Decode the part positions downwards in the tree, following the forward displacements.
  for (let edge = 0; edge < numEdges; ++edge) {
    decodeEdge(edge, childToParentEdges[edge], parentToChildEdges[edge], displacementsFwd);
  }

  return instanceKeypoints;
}
|
|
|
|
/**
 * Decodes a single pose: one candidate keypoint per part, keeping only those whose
 * confidence is strictly greater than `minScore`.
 *
 * Each kept keypoint carries the part name and position that belong to its own part
 * index. Keypoints are built BEFORE filtering so that dropping a low-confidence part
 * cannot shift the indices used for `keypoints.partNames` and the offset point data.
 *
 * @param {object} heatmapScores - Exposes an async `buffer()`; also passed to
 *   `decoders.argmax2d`. Presumably a TensorFlow.js tensor - confirm against callers.
 * @param {object} offsets - Exposes an async `buffer()`. Presumably a TensorFlow.js
 *   tensor - confirm against callers.
 * @param {number} minScore - Exclusive lower bound on keypoint confidence.
 * @returns {Promise<{keypoints: Array<{position: {y: number, x: number}, part: string, score: number}>, box: *, score: number}>}
 *   The kept keypoints, the result of `utils.getBoundingBox` for them, and their mean
 *   confidence rounded to two decimals (0 when no keypoint was kept).
 */
export async function decodeSinglePose(heatmapScores, offsets, minScore) {
  const heatmapValues = decoders.argmax2d(heatmapScores);
  let offsetPoints;
  try {
    const [scoresBuffer, offsetsBuffer, heatmapValuesBuffer] = await Promise.all([
      heatmapScores.buffer(),
      offsets.buffer(),
      heatmapValues.buffer(),
    ]);
    offsetPoints = decoders.getOffsetPoints(heatmapValuesBuffer, defaultOutputStride, offsetsBuffer);
    // Flat data read as consecutive (y, x) pairs, one pair per part index.
    const offsetPointsData = offsetPoints.dataSync();
    const keypointConfidence = decoders.getPointsConfidence(scoresBuffer, heatmapValuesBuffer);

    // Array.from accepts both plain and typed arrays; a typed array's own map()
    // could not hold the keypoint objects.
    const detectedKeypoints = Array.from(keypointConfidence, (score, keypointId) => ({
      position: {
        y: offsetPointsData[2 * keypointId],
        x: offsetPointsData[2 * keypointId + 1],
      },
      part: keypoints.partNames[keypointId],
      score,
    })).filter((keypoint) => keypoint.score > minScore);

    const totalScore = detectedKeypoints.reduce((sum, keypoint) => sum + keypoint.score, 0);
    // Avoid 0 / 0 (NaN) when every keypoint was below the threshold.
    const score = detectedKeypoints.length > 0
      ? Math.round(100 * totalScore / detectedKeypoints.length) / 100
      : 0;
    const box = utils.getBoundingBox(detectedKeypoints);
    return { keypoints: detectedKeypoints, box, score };
  } finally {
    // Release the intermediate tensors even when decoding throws.
    heatmapValues.dispose();
    if (offsetPoints) {
      offsetPoints.dispose();
    }
  }
}
|