human/src/body/posenet.ts

187 lines
8.9 KiB
TypeScript
Raw Normal View History

/**
 * PoseNet body detection model implementation
 *
 * Based on: [**PoseNet**](https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5)
 */
2022-10-17 02:28:57 +02:00
import * as tf from 'dist/tfjs.esm.js';
2022-01-17 17:03:21 +01:00
import { log } from '../util/util';
2022-01-16 15:49:55 +01:00
import { loadModel } from '../tfjs/load';
2021-12-15 15:26:32 +01:00
import type { BodyResult, BodyLandmark, Box } from '../result';
2022-10-17 02:28:57 +02:00
import type { Tensor, GraphModel, Tensor4D } from '../tfjs/types';
2021-09-28 18:01:48 +02:00
import type { Config } from '../config';
import { env } from '../util/env';
import * as utils from './posenetutils';
/** PoseNet graph model; assigned by `load()` and read by `predict()` */
let model: GraphModel;

/** Output node names requested from the graph, in the order `decode()` takes its buffers */
const poseNetOutputs = [
  'MobilenetV1/offset_2/BiasAdd', // offsets
  'MobilenetV1/heatmap_2/BiasAdd', // heatmapScores
  'MobilenetV1/displacement_fwd_2/BiasAdd', // displacementFwd
  'MobilenetV1/displacement_bwd_2/BiasAdd', // displacementBwd
];

/** Half-size (in heatmap cells) of the window a root candidate must dominate */
const localMaximumRadius = 1;
/** Image pixels per heatmap cell: positions are divided by this to index the heatmap and multiplied to map back */
const outputStride = 16;
/** Part-based NMS radius (50 px) squared, so it can be compared against squared distances */
const squaredNmsRadius = 50 ** 2;
/**
 * Walks one edge of the pose tree from an already-decoded keypoint to a neighbouring part.
 *
 * The source position is snapped to the nearest heatmap cell and moved by the displacement stored for
 * `edgeId`. It is then refined `offsetRefineStep` times by re-snapping and adding the target part's offset.
 *
 * @param edgeId index of the edge in the pose chain; selects the displacement channel
 * @param sourceKeypoint decoded keypoint (with `position`) the walk starts from
 * @param targetId part id of the keypoint being decoded
 * @param scores heatmap buffer, read as `scores.get(y, x, partId)`
 * @param offsets offset buffer, read through `utils.getOffsetPoint`
 * @param displacements displacement buffer; channel `edgeId` holds y, channel `shape[2] / 2 + edgeId` holds x
 * @param offsetRefineStep number of snap-and-offset refinement passes
 * @returns keypoint `{ position, part, score }` for `targetId`
 */
function traverse(edgeId: number, sourceKeypoint, targetId, scores, offsets, displacements, offsetRefineStep = 2) {
  const [gridHeight, gridWidth] = scores.shape;
  const halfChannels = displacements.shape[2] / 2;

  // snap an image-space point to the nearest heatmap cell, clamped to the grid
  const toGrid = (pt) => ({
    y: utils.clamp(Math.round(pt.y / outputStride), 0, gridHeight - 1),
    x: utils.clamp(Math.round(pt.x / outputStride), 0, gridWidth - 1),
  });

  // nearest-neighbour lookup of the source->target displacement
  const sourceCell = toGrid(sourceKeypoint.position);
  const shift = {
    y: displacements.get(sourceCell.y, sourceCell.x, edgeId),
    x: displacements.get(sourceCell.y, sourceCell.x, halfChannels + edgeId),
  };
  let position = utils.addVectors(sourceKeypoint.position, shift);

  for (let remaining = offsetRefineStep; remaining > 0; remaining--) {
    const cell = toGrid(position);
    const offset = utils.getOffsetPoint(cell.y, cell.x, targetId, offsets);
    position = utils.addVectors({ x: cell.x * outputStride, y: cell.y * outputStride }, { x: offset.x, y: offset.y });
  }

  const finalCell = toGrid(position);
  return {
    position,
    part: utils.partNames[targetId],
    score: scores.get(finalCell.y, finalCell.x, targetId),
  };
}
/**
 * Decodes one pose by growing the part tree outward from a root candidate.
 *
 * First walks child -> parent along `utils.poseChain` (last edge first) using backward displacements.
 * Then walks parent -> child (first edge first) using forward displacements. A part is only decoded
 * when its source part is already known and it has not been decoded yet.
 *
 * @param root queue entry `{ score, part: { heatmapY, heatmapX, id } }` used as the seed keypoint
 * @param scores heatmap buffer; `scores.shape[2]` is the number of parts
 * @param offsets offset buffer
 * @param displacementsFwd displacement buffer for parent -> child walks
 * @param displacementsBwd displacement buffer for child -> parent walks
 * @returns array indexed by part id; parts that were never reached remain empty slots
 */
export function decodePose(root, scores, offsets, displacementsFwd, displacementsBwd) {
  const parentIds = utils.poseChain.map(([parentName]) => utils.partIds[parentName]);
  const childIds = utils.poseChain.map(([, childName]) => utils.partIds[childName]);
  const edgeCount = childIds.length;
  const keypoints = new Array(scores.shape[2]); // one slot per part, e.g. shape [21,21,17] -> 17 slots

  // seed the pose with the root candidate itself
  keypoints[root.part.id] = {
    score: root.score,
    part: utils.partNames[root.part.id] as BodyLandmark,
    position: utils.getImageCoords(root.part, outputStride, offsets),
  };

  // decode `toId` from `fromId` along `edge`, but only if the source is known and the target is not
  const follow = (edge: number, fromId, toId, displacements) => {
    if (keypoints[fromId] && !keypoints[toId]) keypoints[toId] = traverse(edge, keypoints[fromId], toId, scores, offsets, displacements);
  };

  // upwards through the tree: child -> parent with backward displacements
  for (let edge = edgeCount - 1; edge >= 0; edge--) follow(edge, childIds[edge], parentIds[edge], displacementsBwd);
  // downwards through the tree: parent -> child with forward displacements
  for (let edge = 0; edge < edgeCount; edge++) follow(edge, parentIds[edge], childIds[edge], displacementsFwd);

  return keypoints;
}
/**
 * Reports whether `score` is not exceeded anywhere in the window of part `keypointId` centred on
 * (heatmapY, heatmapX). The window extends `localMaximumRadius` cells in each direction and is clipped
 * to the heatmap bounds. Ties still count as a maximum; only a strictly greater neighbour disqualifies the cell.
 */
function scoreIsMaximumInLocalWindow(keypointId, score: number, heatmapY: number, heatmapX: number, scores) {
  const [height, width]: [number, number] = scores.shape;
  const top = Math.max(heatmapY - localMaximumRadius, 0);
  const bottom = Math.min(heatmapY + localMaximumRadius + 1, height);
  const left = Math.max(heatmapX - localMaximumRadius, 0);
  const right = Math.min(heatmapX + localMaximumRadius + 1, width);
  for (let row = top; row < bottom; row++) {
    for (let col = left; col < right; col++) {
      if (scores.get(row, col, keypointId) > score) return false; // a strictly stronger neighbour exists
    }
  }
  return true;
}
/**
 * Collects root-part candidates into a max-heap keyed by score.
 *
 * A heatmap cell becomes a candidate when its score is not below `minConfidence` and it is a local maximum
 * for its part (see `scoreIsMaximumInLocalWindow`). Cells are scanned row by row, then column, then part.
 *
 * @param minConfidence minimum heatmap score for a candidate
 * @param scores heatmap buffer with shape [height, width, numKeypoints]
 * @returns heap of `{ score, part: { heatmapY, heatmapX, id } }` entries
 */
export function buildPartWithScoreQueue(minConfidence, scores) {
  const [height, width, numKeypoints] = scores.shape;
  const queue = new utils.MaxHeap(height * width * numKeypoints, ({ score }) => score);
  for (let heatmapY = 0; heatmapY < height; heatmapY++) {
    for (let heatmapX = 0; heatmapX < width; heatmapX++) {
      for (let id = 0; id < numKeypoints; id++) {
        const score = scores.get(heatmapY, heatmapX, id);
        const belowThreshold = score < minConfidence;
        if (!belowThreshold && scoreIsMaximumInLocalWindow(id, score, heatmapY, heatmapX, scores)) {
          queue.enqueue({ score, part: { heatmapY, heatmapX, id } });
        }
      }
    }
  }
  return queue;
}
/**
 * Part-based NMS test: reports whether any already-accepted pose has its keypoint for part `keypointId`
 * within `squaredNmsRadius` (squared distance) of the point `{ x, y }`.
 *
 * The corresponding keypoint is found by part name, not by array index. `decode()` drops low-confidence
 * keypoints from each pose before storing it, so array positions no longer equal part ids.
 *
 * @param poses previously accepted poses, each with a `keypoints` array of `{ part, position }` entries
 * @param point candidate position, in the same coordinates as `utils.getImageCoords` output
 * @param keypointId part id of the candidate (index into `utils.partNames`)
 * @returns true when the candidate should be suppressed
 */
function withinRadius(poses, { x, y }, keypointId) {
  const partName = utils.partNames[keypointId];
  return poses.some(({ keypoints }) => {
    // `?.` tolerates unfiltered (sparse) keypoint arrays as well
    const corresponding = keypoints.find((keypoint) => keypoint?.part === partName)?.position;
    if (!corresponding) return false;
    return utils.squaredDistance(y, x, corresponding.y, corresponding.x) <= squaredNmsRadius;
  });
}
/**
 * Instance score of a decoded pose: the sum of scores of keypoints that do not overlap the same part
 * of any existing pose, divided by the number of keypoints.
 *
 * Overlap is tested per part name (squared distance <= `squaredNmsRadius`). Array indices are not used,
 * because `decode()` stores filtered keypoint arrays whose positions no longer equal part ids.
 *
 * @param existingPoses already-accepted poses with `keypoints` arrays of `{ part, position }`
 * @param keypoints keypoints `{ part, position, score }` of the candidate pose
 * @returns mean non-overlapping score, or 0 when `keypoints` is empty
 */
function getInstanceScore(existingPoses, keypoints) {
  if (keypoints.length === 0) return 0; // avoid 0 / 0 = NaN
  const overlapsExisting = ({ part, position }) => existingPoses.some((pose) => {
    const other = pose.keypoints.find((keypoint) => keypoint?.part === part)?.position;
    return !!other && utils.squaredDistance(position.y, position.x, other.y, other.x) <= squaredNmsRadius;
  });
  const notOverlappedKeypointScores = keypoints.reduce((sum, keypoint) => (overlapsExisting(keypoint) ? sum : sum + keypoint.score), 0.0);
  return notOverlappedKeypointScores / keypoints.length;
}
/**
 * Greedy multi-pose decoding over the PoseNet output buffers.
 *
 * Root candidates are taken from a max-heap in decreasing score order. A candidate is skipped when an
 * accepted pose already has the same part within the NMS radius. Each surviving root is grown into a
 * pose, and keypoints with `score <= minConfidence` are dropped. The pose is accepted when its instance
 * score exceeds `minConfidence`.
 *
 * @param offsets offset buffer
 * @param scores heatmap buffer (`predict()` applies a sigmoid before calling this)
 * @param displacementsFwd forward displacement buffer
 * @param displacementsBwd backward displacement buffer
 * @param maxDetected maximum number of poses returned
 * @param minConfidence threshold applied to root candidates, keypoints and pose scores
 * @returns accepted poses with `keypoints`, bounding `box` and `score` rounded to two decimals
 */
export function decode(offsets, scores, displacementsFwd, displacementsBwd, maxDetected, minConfidence) {
  const poses: { keypoints, box: Box, score: number }[] = [];
  const queue = buildPartWithScoreQueue(minConfidence, scores);
  // Generate at most maxDetected object instances per image in decreasing root part score order.
  while (poses.length < maxDetected && !queue.empty()) {
    // The top element in the queue is the next root candidate.
    const root = queue.dequeue();
    // Part-based non-maximum suppression: reject a root candidate that lies within the NMS radius of the same part of an accepted pose.
    // NOTE(review): @ts-ignore suppressions kept from the original; presumably the dequeued entry is not typed with `part` - TODO use a typed queue entry.
    // @ts-ignore this one is tree walk
    const rootImageCoords = utils.getImageCoords(root.part, outputStride, offsets);
    // @ts-ignore this one is tree walk
    if (withinRadius(poses, rootImageCoords, root.part.id)) continue;
    // Grow a full pose from the root and keep only confident keypoints.
    const keypoints = decodePose(root, scores, offsets, displacementsFwd, displacementsBwd).filter((keypoint) => keypoint.score > minConfidence);
    // No keypoint survived the filter: there is nothing to score or box, so skip this root.
    if (keypoints.length === 0) continue;
    const score = getInstanceScore(poses, keypoints);
    // The bounding box is only computed for poses that are actually accepted.
    if (score > minConfidence) poses.push({ keypoints, box: utils.getBoundingBox(keypoints), score: Math.round(100 * score) / 100 });
  }
  return poses;
}
/**
 * Runs PoseNet on an image tensor and returns decoded, input-scaled poses.
 *
 * Returns an empty list when the model is not loaded or its input shape is unknown.
 *
 * @param input image batch; `input.shape[1]` / `input.shape[2]` are used as height / width when scaling results
 * @param config uses `config.body.maxDetected` and `config.body.minConfidence`
 */
export async function predict(input: Tensor4D, config: Config): Promise<BodyResult[]> {
  /** posenet is mostly obsolete
   * caching is not implemented
   */
  if (!model?.['executor']) return [];
  const inputShape = model.inputs[0].shape;
  if (!inputShape) return [];
  // NOTE(review): assumes an NHWC model input [batch, height, width, channels] - TODO confirm for custom models
  const modelSize: [number, number] = [inputShape[1], inputShape[2]];
  const res = tf.tidy(() => {
    const resized = tf.image.resizeBilinear(input, modelSize); // size argument is [newHeight, newWidth]
    const normalized = tf.sub(tf.div(tf.cast(resized, 'float32'), 127.5), 1.0); // maps 0..255 to -1..1
    const results: Tensor[] = model.execute(normalized, poseNetOutputs) as Tensor[];
    const results3d = results.map((y) => tf.squeeze(y, [0])); // drop the batch dimension
    results3d[1] = tf.sigmoid(results3d[1]); // apply sigmoid on scores
    return results3d;
  });
  const buffers = await Promise.all(res.map((tensor: Tensor) => tensor.buffer()));
  tf.dispose(res);
  const decoded = decode(buffers[0], buffers[1], buffers[2], buffers[3], config.body.maxDetected, config.body.minConfidence);
  const scaled = utils.scalePoses(decoded, [input.shape[1], input.shape[2]], modelSize);
  return scaled;
}
/**
 * Loads the PoseNet graph model from `config.body.modelPath`.
 * The cached instance is reused unless no model exists yet or `env.initial` is set.
 * When a cached model is reused and `config.debug` is on, its URL is logged.
 */
export async function load(config: Config): Promise<GraphModel> {
  const mustLoad = !model || env.initial;
  if (mustLoad) {
    model = await loadModel(config.body.modelPath);
  } else if (config.debug) {
    log('cached model:', model['modelUrl']);
  }
  return model;
}