// human/src/posenet/poses.ts
//
// Extracted from a repository web view (141 lines, 7.2 KiB, TypeScript);
// blame timestamp 2021-04-24 22:04:49 +02:00.
import * as utils from './utils';
import * as kpt from './keypoints';
/** Radius, in heatmap cells, of the window a root candidate must dominate to count as a local maximum. */
const localMaximumRadius = 1;
/** Output stride used by `decode` to convert heatmap cells into image coordinates. */
const defaultOutputStride = 16;
/** Squared NMS radius (radius 20): a root candidate this close to the same part of an accepted pose is rejected. */
const squaredNmsRadius = 20 ** 2;
/**
 * Follows one edge of the pose chain from an already-located keypoint to a neighbouring part.
 *
 * The source position is snapped to its nearest in-bounds heatmap cell, and the displacement
 * stored for `edgeId` at that cell is added to it. The resulting point is then refined
 * `offsetRefineStep` times: each pass re-snaps it to the grid and applies the target part's offset.
 *
 * @param edgeId index of the pose-chain edge; the y displacement is read from channel `edgeId`
 *   and the x displacement from channel `edgeId` + half the displacement channel count
 * @param sourceKeypoint keypoint whose `position` is the starting point
 * @param targetKeypointId part id of the keypoint being located
 * @param scoresBuffer heatmap buffer read as (y, x, partId); its first two shape entries bound the grid
 * @param offsets offset buffer handed to `utils.getOffsetPoint`
 * @param outputStride scale factor between heatmap cells and image coordinates
 * @param displacements displacement buffer read as (y, x, channel)
 * @param offsetRefineStep number of snap-and-offset refinement passes
 * @returns the located keypoint: refined position, part name and heatmap score at that position
 */
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
  const [gridHeight, gridWidth] = scoresBuffer.shape;
  const xChannel = (displacements.shape[2] / 2) + edgeId;

  // Map a point to the closest heatmap cell, clamped to the grid.
  const toCell = (point) => ({
    y: utils.clamp(Math.round(point.y / outputStride), 0, gridHeight - 1),
    x: utils.clamp(Math.round(point.x / outputStride), 0, gridWidth - 1),
  });

  // Nearest-neighbour lookup of the source->target displacement.
  const sourceCell = toCell(sourceKeypoint.position);
  const displacement = {
    y: displacements.get(sourceCell.y, sourceCell.x, edgeId),
    x: displacements.get(sourceCell.y, sourceCell.x, xChannel),
  };
  let position = utils.addVectors(sourceKeypoint.position, displacement);

  // Snap to the grid and re-apply the target part's offset, a fixed number of times.
  for (let step = 0; step < offsetRefineStep; step += 1) {
    const cell = toCell(position);
    const offset = utils.getOffsetPoint(cell.y, cell.x, targetKeypointId, offsets);
    const cellOrigin = { x: cell.x * outputStride, y: cell.y * outputStride };
    position = utils.addVectors(cellOrigin, { x: offset.x, y: offset.y });
  }

  const finalCell = toCell(position);
  const score = scoresBuffer.get(finalCell.y, finalCell.x, targetKeypointId);
  return { position, part: kpt.partNames[targetKeypointId], score };
}
/**
 * Decodes one pose instance, starting from a single root part.
 *
 * The root is placed with `utils.getImageCoords`. The other parts are then reached by walking
 * `kpt.poseChain` twice:
 *   1. from the last edge to the first, child -> parent, using the backward displacements;
 *   2. from the first edge to the last, parent -> child, using the forward displacements.
 * An edge is followed only when its source part is already located and its target part is not,
 * so a part is never overwritten once found.
 *
 * @param root queue entry `{ score, part: { heatmapY, heatmapX, id } }`
 * @param scores heatmap buffer; its third shape entry is the number of parts
 * @param offsets offset buffer
 * @param outputStride scale factor between heatmap cells and image coordinates
 * @param displacementsFwd displacements used for parent -> child edges
 * @param displacementsBwd displacements used for child -> parent edges
 * @returns array of length `scores.shape[2]` indexed by part id; slots for parts never
 *   reached along the chain stay empty
 */
export function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
  const edges = kpt.poseChain.map(([parentName, childName]) => ({
    parentId: kpt.partIds[parentName],
    childId: kpt.partIds[childName],
  }));
  const keypoints = new Array(scores.shape[2]);

  // Seed the instance with the root part.
  const { part: rootPart, score: rootScore } = root;
  const rootPosition = utils.getImageCoords(rootPart, outputStride, offsets);
  keypoints[rootPart.id] = { score: rootScore, part: kpt.partNames[rootPart.id], position: rootPosition };

  // Locate `toId` from `fromId` across `edge`, unless the source is missing or the target is already set.
  const follow = (edge, fromId, toId, displacements) => {
    if (!keypoints[fromId] || keypoints[toId]) return;
    keypoints[toId] = traverseToTargetKeypoint(edge, keypoints[fromId], toId, scores, offsets, outputStride, displacements);
  };

  // Upwards pass: child -> parent along backward displacements, last edge first.
  for (let edge = edges.length - 1; edge >= 0; edge -= 1) {
    follow(edge, edges[edge].childId, edges[edge].parentId, displacementsBwd);
  }
  // Downwards pass: parent -> child along forward displacements, first edge first.
  edges.forEach(({ parentId, childId }, edge) => follow(edge, parentId, childId, displacementsFwd));

  return keypoints;
}
/**
 * Reports whether no cell in the window around (heatmapY, heatmapX) exceeds `score` for part `keypointId`.
 *
 * The window extends `localMaximumRadius` cells in each direction and is clipped to the heatmap bounds
 * (`scores.shape`). The centre cell is included, and ties do not disqualify, so equal
 * neighbours still count as a maximum.
 */
function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, scores) {
  const [height, width] = scores.shape;
  const top = Math.max(heatmapY - localMaximumRadius, 0);
  const bottom = Math.min(heatmapY + localMaximumRadius + 1, height);
  const left = Math.max(heatmapX - localMaximumRadius, 0);
  const right = Math.min(heatmapX + localMaximumRadius + 1, width);
  for (let y = top; y < bottom; y++) {
    for (let x = left; x < right; x++) {
      if (scores.get(y, x, keypointId) > score) return false;
    }
  }
  return true;
}
/**
 * Collects root-part candidates into a max-heap ordered by score.
 *
 * A heatmap cell (y, x, part) becomes a candidate when its score is not below `minConfidence`
 * and it is a local maximum for its part. Each candidate is enqueued as
 * `{ score, part: { heatmapY, heatmapX, id } }`.
 *
 * @param minConfidence inclusive lower bound for candidate scores
 * @param scores heatmap buffer read as (y, x, partId)
 * @returns a `utils.MaxHeap` whose capacity covers every cell of `scores`
 */
export function buildPartWithScoreQueue(minConfidence, scores) {
  const [height, width, numKeypoints] = scores.shape;
  const capacity = height * width * numKeypoints;
  const queue = new utils.MaxHeap(capacity, (entry) => entry.score);
  for (let heatmapY = 0; heatmapY < height; heatmapY++) {
    for (let heatmapX = 0; heatmapX < width; heatmapX++) {
      for (let id = 0; id < numKeypoints; id++) {
        const score = scores.get(heatmapY, heatmapX, id);
        // Below-threshold cells are never root candidates.
        const belowThreshold = score < minConfidence;
        if (!belowThreshold && scoreIsMaximumInLocalWindow(id, score, heatmapY, heatmapX, scores)) {
          queue.enqueue({ score, part: { heatmapY, heatmapX, id } });
        }
      }
    }
  }
  return queue;
}
/**
 * Part-based NMS test. Returns true when the candidate point `{ x, y }` lies within the NMS radius
 * (squared distance <= `squaredNmsRadius`) of the same part in at least one of `poses`.
 *
 * The poses accepted by `decode` keep only their confidence-filtered keypoints, so a pose's
 * `keypoints` array is NOT indexed by part id. The corresponding keypoint is therefore looked up
 * by part name. A pose that does not contain that part cannot suppress the candidate.
 *
 * @param poses previously accepted poses, each with a `keypoints` array of `{ part, position }`
 * @param keypointId part id of the candidate point
 */
function withinRadius(poses, { x, y }, keypointId) {
  const partName = kpt.partNames[keypointId];
  return poses.some(({ keypoints }) => {
    // `find` visits empty slots as undefined, hence the truthiness guard.
    const corresponding = keypoints.find((keypoint) => keypoint && keypoint.part === partName);
    if (!corresponding) return false;
    const { position } = corresponding;
    return utils.squaredDistance(y, x, position.y, position.x) <= squaredNmsRadius;
  });
}
/**
 * Scores a candidate instance as the mean score of its keypoints. A keypoint that falls within
 * the NMS radius of the same part in an already accepted pose contributes 0 to the sum.
 *
 * `decode` passes a confidence-filtered keypoint list, so the array index is not the part id.
 * The id is recovered from each keypoint's part name via `kpt.partIds`.
 *
 * @param existingPoses poses accepted so far
 * @param instanceKeypoints keypoints `{ part, position, score }` of the candidate
 * @returns the averaged score, or 0 when `instanceKeypoints` is empty
 */
function getInstanceScore(existingPoses, instanceKeypoints) {
  if (instanceKeypoints.length === 0) return 0;
  const notOverlappedKeypointScores = instanceKeypoints.reduce((sum, { part, position, score }) => (
    withinRadius(existingPoses, position, kpt.partIds[part]) ? sum : sum + score
  ), 0.0);
  return notOverlappedKeypointScores / instanceKeypoints.length;
}
/**
 * Multi-pose decoder. Greedily builds up to `maxDetected` poses from the strongest root-part
 * candidates, applying part-based non-maximum suppression.
 *
 * @param offsetsBuffer offset buffer
 * @param scoresBuffer heatmap buffer read as (y, x, partId)
 * @param displacementsFwdBuffer displacements for parent -> child edges
 * @param displacementsBwdBuffer displacements for child -> parent edges
 * @param maxDetected maximum number of poses to return
 * @param minConfidence threshold applied in three places:
 *   - root candidates must score at least this value;
 *   - keypoints are kept only above it;
 *   - the instance score must be above it.
 * @returns accepted poses, in the order their roots were dequeued. Each pose has:
 *   - its keypoints scoring above `minConfidence`;
 *   - their bounding box;
 *   - its instance score rounded to two decimals.
 */
export function decode(offsetsBuffer, scoresBuffer, displacementsFwdBuffer, displacementsBwdBuffer, maxDetected, minConfidence) {
  const poses: Array<{ keypoints: any, box: any, score: number }> = [];
  const queue = buildPartWithScoreQueue(minConfidence, scoresBuffer);
  // Root candidates come out of the max-heap in decreasing score order.
  while (poses.length < maxDetected && !queue.empty()) {
    const root = queue.dequeue();
    // Part-based NMS: reject a root that lies within the NMS radius of the same part of an accepted pose.
    const rootImageCoords = utils.getImageCoords(root.part, defaultOutputStride, offsetsBuffer);
    if (withinRadius(poses, rootImageCoords, root.part.id)) continue;
    // Otherwise grow a full instance from this root and keep only its confident keypoints.
    const allKeypoints = decodePose(root, scoresBuffer, offsetsBuffer, defaultOutputStride, displacementsFwdBuffer, displacementsBwdBuffer);
    const keypoints = allKeypoints.filter((keypoint) => keypoint.score > minConfidence);
    const score = getInstanceScore(poses, keypoints);
    // Compute the bounding box only for instances that are actually kept.
    if (score > minConfidence) {
      poses.push({ keypoints, box: utils.getBoundingBox(keypoints), score: Math.round(100 * score) / 100 });
    }
  }
  return poses;
}