human/src/posenet/decodeSingle.ts

105 lines
5.1 KiB
TypeScript

import * as keypoints from './keypoints';
import * as decoders from './decodeParts';
import * as utils from './utils';
const parentChildrenTuples = keypoints.poseChain.map(([parentJoinName, childJoinName]) => ([keypoints.partIds[parentJoinName], keypoints.partIds[childJoinName]]));
const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId);
const childToParentEdges = parentChildrenTuples.map(([parentJointId]) => parentJointId);
const defaultOutputStride = 16;
function getDisplacement(edgeId, point, displacements) {
const numEdges = displacements.shape[2] / 2;
return {
y: displacements.get(point.y, point.x, edgeId),
x: displacements.get(point.y, point.x, numEdges + edgeId),
};
}
function getStridedIndexNearPoint(point, outputStride, height, width) {
return {
y: utils.clamp(Math.round(point.y / outputStride), 0, height - 1),
x: utils.clamp(Math.round(point.x / outputStride), 0, width - 1),
};
}
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
const [height, width] = scoresBuffer.shape;
// Nearest neighbor interpolation for the source->target displacements.
const sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypoint.position, outputStride, height, width);
const displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements);
const displacedPoint = utils.addVectors(sourceKeypoint.position, displacement);
let targetKeypoint = displacedPoint;
for (let i = 0; i < offsetRefineStep; i++) {
const targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
const offsetPoint = utils.getOffsetPoint(targetKeypointIndices.y, targetKeypointIndices.x, targetKeypointId, offsets);
targetKeypoint = utils.addVectors({
x: targetKeypointIndices.x * outputStride,
y: targetKeypointIndices.y * outputStride,
}, { x: offsetPoint.x, y: offsetPoint.y });
}
const targetKeyPointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
const score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId);
return { position: targetKeypoint, part: keypoints.partNames[targetKeypointId], score };
}
export function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
const numParts = scores.shape[2];
const numEdges = parentToChildEdges.length;
const instanceKeypoints = new Array(numParts);
// Start a new detection instance at the position of the root.
const { part: rootPart, score: rootScore } = root;
const rootPoint = utils.getImageCoords(rootPart, outputStride, offsets);
instanceKeypoints[rootPart.id] = {
score: rootScore,
part: keypoints.partNames[rootPart.id],
position: rootPoint,
};
// Decode the part positions upwards in the tree, following the backward displacements.
for (let edge = numEdges - 1; edge >= 0; --edge) {
const sourceKeypointId = parentToChildEdges[edge];
const targetKeypointId = childToParentEdges[edge];
if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsBwd);
}
}
// Decode the part positions downwards in the tree, following the forward displacements.
for (let edge = 0; edge < numEdges; ++edge) {
const sourceKeypointId = childToParentEdges[edge];
const targetKeypointId = parentToChildEdges[edge];
if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsFwd);
}
}
return instanceKeypoints;
}
export async function decodeSinglePose(heatmapScores, offsets, minScore) {
const heatmapValues = decoders.argmax2d(heatmapScores);
const allTensorBuffers = await Promise.all([heatmapScores.buffer(), offsets.buffer(), heatmapValues.buffer()]);
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const heatmapValuesBuffer = allTensorBuffers[2];
const offsetPoints = decoders.getOffsetPoints(heatmapValuesBuffer, defaultOutputStride, offsetsBuffer);
const offsetPointsData = offsetPoints.dataSync();
const keypointConfidence = decoders.getPointsConfidence(scoresBuffer, heatmapValuesBuffer);
let avgScore = 0.0;
const filteredKeypoints = keypointConfidence
.filter((score) => score > minScore)
.map((score, i) => {
avgScore += score;
return {
position: {
y: offsetPointsData[2 * i + 0], // offsetPointsBuffer.get(i, 0),
x: offsetPointsData[2 * i + 1], // offsetPointsBuffer.get(i, 1),
},
part: keypoints.partNames[i],
score,
};
});
heatmapValues.dispose();
offsetPoints.dispose();
const box = utils.getBoundingBox(filteredKeypoints);
return { keypoints: filteredKeypoints, box, score: Math.round(100 * avgScore / filteredKeypoints.length) / 100 };
}