major work on body module

pull/50/head
Vladimir Mandic 2020-12-16 18:36:24 -05:00
parent bbebaccd3b
commit 1a73761da0
24 changed files with 422 additions and 450 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
node_modules
alternative

View File

@ -138,6 +138,9 @@ export default {
scoreThreshold: 0.5, // threshold for deciding when to remove boxes based on score
// in non-maximum suppression
nmsRadius: 20, // radius for deciding when points are too close in non-maximum suppression
outputStride: 16, // size of block in which to run point detection, smaller value means higher resolution
// defined by model itself, can be 8, 16, or 32
modelType: 'MobileNet', // Human includes MobileNet version, but you can switch to ResNet
},
hand: {

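The new body options above are meant to be overridden per instance through the same userConfig mechanism the demo below passes to new Human(). A minimal sketch; scoreThreshold, nmsRadius, outputStride and modelType are the values shown in this hunk, the rest is illustrative:

// illustrative override of the body-module options documented above
const userConfig = {
  body: {
    enabled: true,
    modelType: 'MobileNet', // or 'ResNet' if a ResNet graph model is supplied
    outputStride: 16,       // must match the model: 8, 16, or 32
    maxDetections: 5,       // assumed value, not shown in this hunk
    scoreThreshold: 0.5,
    nmsRadius: 20,
  },
};
const human = new Human(userConfig);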
View File

@ -4,6 +4,13 @@ import Menu from './menu.js';
import GLBench from './gl-bench.js';
const userConfig = {}; // add any user configuration overrides
/*
const userConfig = {
face: { enabled: false },
body: { enabled: true },
hand: { enabled: false },
};
*/
const human = new Human(userConfig);

View File

@ -136,49 +136,57 @@ async function drawBody(result, canvas, ui) {
}
if (ui.drawPolygons) {
const path = new Path2D();
let root;
let part;
// torso
part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.moveTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightShoulder');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightHip');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftHip');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.lineTo(part.position.x, part.position.y);
// legs
part = result[i].keypoints.find((a) => a.part === 'leftHip');
path.moveTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftKnee');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftAnkle');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightHip');
path.moveTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightKnee');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightAnkle');
path.lineTo(part.position.x, part.position.y);
// arms
part = result[i].keypoints.find((a) => a.part === 'rightShoulder');
path.moveTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftElbow');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftWrist');
path.lineTo(part.position.x, part.position.y);
// arms
part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.moveTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightShoulder');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightElbow');
path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightWrist');
path.lineTo(part.position.x, part.position.y);
root = result[i].keypoints.find((a) => a.part === 'leftShoulder');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightShoulder');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightHip');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftHip');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
if (part) path.lineTo(part.position.x, part.position.y);
}
// leg left
root = result[i].keypoints.find((a) => a.part === 'leftHip');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftKnee');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftAnkle');
if (part) path.lineTo(part.position.x, part.position.y);
}
// leg right
root = result[i].keypoints.find((a) => a.part === 'rightHip');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightKnee');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightAnkle');
if (part) path.lineTo(part.position.x, part.position.y);
}
// arm left
root = result[i].keypoints.find((a) => a.part === 'leftShoulder');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftElbow');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftWrist');
if (part) path.lineTo(part.position.x, part.position.y);
}
// arm right
root = result[i].keypoints.find((a) => a.part === 'rightShoulder');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightElbow');
if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightWrist');
if (part) path.lineTo(part.position.x, part.position.y);
}
// draw all
ctx.stroke(path);
}

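The rewritten drawBody above now guards every keypoint lookup, at the cost of repeating the find/moveTo/lineTo pattern for each limb. Purely as an illustrative follow-up (the helper name drawChain is not part of this commit), the same logic could be factored as:

// hypothetical helper, assuming the same result[i].keypoints structure used above
function drawChain(path, keypoints, partNames) {
  const points = partNames
    .map((name) => keypoints.find((a) => a.part === name))
    .filter((p) => p); // drop parts the model did not report
  if (points.length < 2) return;
  path.moveTo(points[0].position.x, points[0].position.y);
  for (const point of points.slice(1)) path.lineTo(point.position.x, point.position.y);
}
// torso as a closed loop, then each limb as its own chain
drawChain(path, result[i].keypoints, ['leftShoulder', 'rightShoulder', 'rightHip', 'leftHip', 'leftShoulder']);
drawChain(path, result[i].keypoints, ['rightShoulder', 'rightElbow', 'rightWrist']);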
File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,11 +1,11 @@
{
"inputs": {
"dist/human.esm.js": {
"bytes": 1836264,
"bytes": 1839401,
"imports": []
},
"demo/draw.js": {
"bytes": 10630,
"bytes": 10733,
"imports": []
},
"demo/menu.js": {
@ -17,7 +17,7 @@
"imports": []
},
"demo/browser.js": {
"bytes": 25337,
"bytes": 25450,
"imports": [
{
"path": "dist/human.esm.js"
@ -38,17 +38,17 @@
"dist/demo-browser-index.js.map": {
"imports": [],
"inputs": {},
"bytes": 2198146
"bytes": 2199701
},
"dist/demo-browser-index.js": {
"imports": [],
"exports": [],
"inputs": {
"dist/human.esm.js": {
"bytesInOutput": 1829024
"bytesInOutput": 1832161
},
"demo/draw.js": {
"bytesInOutput": 7816
"bytesInOutput": 7726
},
"demo/menu.js": {
"bytesInOutput": 11800
@ -60,7 +60,7 @@
"bytesInOutput": 19539
}
},
"bytes": 1882950
"bytes": 1885997
}
}
}

4
dist/human.esm.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

74
dist/human.esm.json vendored
View File

@ -148,30 +148,19 @@
]
},
"src/body/modelBase.js": {
"bytes": 889,
"bytes": 1343,
"imports": [
{
"path": "dist/tfjs.esm.js"
}
]
},
"src/body/modelMobileNet.js": {
"bytes": 599,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelBase.js"
}
]
},
"src/body/heapSort.js": {
"bytes": 1590,
"imports": []
},
"src/body/buildParts.js": {
"bytes": 2035,
"bytes": 1775,
"imports": [
{
"path": "src/body/heapSort.js"
@ -179,7 +168,7 @@
]
},
"src/body/keypoints.js": {
"bytes": 2291,
"bytes": 2011,
"imports": []
},
"src/body/vectors.js": {
@ -190,19 +179,33 @@
}
]
},
"src/body/decoders.js": {
"bytes": 2083,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/keypoints.js"
}
]
},
"src/body/decodePose.js": {
"bytes": 4530,
"bytes": 5216,
"imports": [
{
"path": "src/body/keypoints.js"
},
{
"path": "src/body/vectors.js"
},
{
"path": "src/body/decoders.js"
}
]
},
"src/body/decodeMultiple.js": {
"bytes": 5608,
"bytes": 2303,
"imports": [
{
"path": "src/body/buildParts.js"
@ -224,7 +227,7 @@
]
},
"src/body/modelPoseNet.js": {
"bytes": 1905,
"bytes": 2395,
"imports": [
{
"path": "src/log.js"
@ -233,28 +236,25 @@
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelMobileNet.js"
"path": "src/body/modelBase.js"
},
{
"path": "src/body/decodeMultiple.js"
},
{
"path": "src/body/decodePose.js"
},
{
"path": "src/body/util.js"
}
]
},
"src/body/posenet.js": {
"bytes": 830,
"bytes": 614,
"imports": [
{
"path": "src/body/modelMobileNet.js"
},
{
"path": "src/body/modelPoseNet.js"
},
{
"path": "src/body/decodeMultiple.js"
},
{
"path": "src/body/keypoints.js"
},
@ -350,7 +350,7 @@
]
},
"config.js": {
"bytes": 9241,
"bytes": 9530,
"imports": []
},
"src/sample.js": {
@ -419,7 +419,7 @@
"dist/human.esm.js.map": {
"imports": [],
"inputs": {},
"bytes": 2104843
"bytes": 2106203
},
"dist/human.esm.js": {
"imports": [],
@ -461,10 +461,7 @@
"bytesInOutput": 1318
},
"src/body/modelBase.js": {
"bytesInOutput": 615
},
"src/body/modelMobileNet.js": {
"bytesInOutput": 375
"bytesInOutput": 1080
},
"src/body/heapSort.js": {
"bytesInOutput": 1139
@ -478,20 +475,23 @@
"src/body/vectors.js": {
"bytesInOutput": 1050
},
"src/body/decoders.js": {
"bytesInOutput": 1722
},
"src/body/decodePose.js": {
"bytesInOutput": 3111
"bytesInOutput": 4161
},
"src/body/decodeMultiple.js": {
"bytesInOutput": 1684
"bytesInOutput": 1698
},
"src/body/util.js": {
"bytesInOutput": 1913
},
"src/body/modelPoseNet.js": {
"bytesInOutput": 1569
"bytesInOutput": 2002
},
"src/body/posenet.js": {
"bytesInOutput": 832
"bytesInOutput": 622
},
"src/hand/handdetector.js": {
"bytesInOutput": 2742
@ -533,7 +533,7 @@
"bytesInOutput": 1796
},
"config.js": {
"bytesInOutput": 1454
"bytesInOutput": 1492
},
"src/sample.js": {
"bytesInOutput": 55295
@ -542,7 +542,7 @@
"bytesInOutput": 21
}
},
"bytes": 1836264
"bytes": 1839401
}
}
}

4
dist/human.js vendored

File diff suppressed because one or more lines are too long

6
dist/human.js.map vendored

File diff suppressed because one or more lines are too long

74
dist/human.json vendored
View File

@ -148,30 +148,19 @@
]
},
"src/body/modelBase.js": {
"bytes": 889,
"bytes": 1343,
"imports": [
{
"path": "dist/tfjs.esm.js"
}
]
},
"src/body/modelMobileNet.js": {
"bytes": 599,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelBase.js"
}
]
},
"src/body/heapSort.js": {
"bytes": 1590,
"imports": []
},
"src/body/buildParts.js": {
"bytes": 2035,
"bytes": 1775,
"imports": [
{
"path": "src/body/heapSort.js"
@ -179,7 +168,7 @@
]
},
"src/body/keypoints.js": {
"bytes": 2291,
"bytes": 2011,
"imports": []
},
"src/body/vectors.js": {
@ -190,19 +179,33 @@
}
]
},
"src/body/decoders.js": {
"bytes": 2083,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/keypoints.js"
}
]
},
"src/body/decodePose.js": {
"bytes": 4530,
"bytes": 5216,
"imports": [
{
"path": "src/body/keypoints.js"
},
{
"path": "src/body/vectors.js"
},
{
"path": "src/body/decoders.js"
}
]
},
"src/body/decodeMultiple.js": {
"bytes": 5608,
"bytes": 2303,
"imports": [
{
"path": "src/body/buildParts.js"
@ -224,7 +227,7 @@
]
},
"src/body/modelPoseNet.js": {
"bytes": 1905,
"bytes": 2395,
"imports": [
{
"path": "src/log.js"
@ -233,28 +236,25 @@
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelMobileNet.js"
"path": "src/body/modelBase.js"
},
{
"path": "src/body/decodeMultiple.js"
},
{
"path": "src/body/decodePose.js"
},
{
"path": "src/body/util.js"
}
]
},
"src/body/posenet.js": {
"bytes": 830,
"bytes": 614,
"imports": [
{
"path": "src/body/modelMobileNet.js"
},
{
"path": "src/body/modelPoseNet.js"
},
{
"path": "src/body/decodeMultiple.js"
},
{
"path": "src/body/keypoints.js"
},
@ -350,7 +350,7 @@
]
},
"config.js": {
"bytes": 9241,
"bytes": 9530,
"imports": []
},
"src/sample.js": {
@ -419,7 +419,7 @@
"dist/human.js.map": {
"imports": [],
"inputs": {},
"bytes": 2121966
"bytes": 2123326
},
"dist/human.js": {
"imports": [],
@ -459,10 +459,7 @@
"bytesInOutput": 1318
},
"src/body/modelBase.js": {
"bytesInOutput": 615
},
"src/body/modelMobileNet.js": {
"bytesInOutput": 375
"bytesInOutput": 1080
},
"src/body/heapSort.js": {
"bytesInOutput": 1139
@ -476,20 +473,23 @@
"src/body/vectors.js": {
"bytesInOutput": 1050
},
"src/body/decoders.js": {
"bytesInOutput": 1722
},
"src/body/decodePose.js": {
"bytesInOutput": 3111
"bytesInOutput": 4161
},
"src/body/decodeMultiple.js": {
"bytesInOutput": 1684
"bytesInOutput": 1698
},
"src/body/util.js": {
"bytesInOutput": 1913
},
"src/body/modelPoseNet.js": {
"bytesInOutput": 1569
"bytesInOutput": 2002
},
"src/body/posenet.js": {
"bytesInOutput": 832
"bytesInOutput": 622
},
"src/hand/handdetector.js": {
"bytesInOutput": 2742
@ -531,7 +531,7 @@
"bytesInOutput": 1796
},
"config.js": {
"bytesInOutput": 1454
"bytesInOutput": 1492
},
"src/sample.js": {
"bytesInOutput": 55295
@ -540,7 +540,7 @@
"bytesInOutput": 21
}
},
"bytes": 1836338
"bytes": 1839475
}
}
}

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -14,17 +14,11 @@ function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, loca
break;
}
}
if (!localMaximum) {
break;
}
if (!localMaximum) break;
}
return localMaximum;
}
/**
* Builds a priority queue with part candidate positions for a specific image in
* the batch. For this we find all local maxima in the score maps with score
* values above a threshold. We create a single priority queue across all parts.
*/
function buildPartWithScoreQueue(scoreThreshold, localMaximumRadius, scores) {
const [height, width, numKeypoints] = scores.shape;
const queue = new heapSort.MaxHeap(height * width * numKeypoints, ({ score }) => score);

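This hunk stops right after the MaxHeap is allocated. For orientation, the remainder of buildPartWithScoreQueue follows this shape, pushing every local score maximum above scoreThreshold onto a single queue across all parts (a sketch reconstructed from the surrounding definitions, not the verbatim file):

// sketch: scan the score maps and enqueue local-maximum part candidates
for (let heatmapY = 0; heatmapY < height; heatmapY++) {
  for (let heatmapX = 0; heatmapX < width; heatmapX++) {
    for (let keypointId = 0; keypointId < numKeypoints; keypointId++) {
      const score = scores.get(heatmapY, heatmapX, keypointId);
      if (score < scoreThreshold) continue;
      if (scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) {
        queue.enqueue({ score, part: { heatmapY, heatmapX, id: keypointId } });
      }
    }
  }
}
// the queue is then returned and consumed by decodeMultiplePoses below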
View File

@ -2,16 +2,15 @@ import * as buildParts from './buildParts';
import * as decodePose from './decodePose';
import * as vectors from './vectors';
const kLocalMaximumRadius = 1;
function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, { x, y }, keypointId) {
return poses.some(({ keypoints }) => {
const correspondingKeypoint = keypoints[keypointId].position;
return vectors.squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius;
});
}
/* Score the newly proposed object instance without taking into account
* the scores of the parts that overlap with any previously detected
* instance.
*/
function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
const notOverlappedKeypointScores = instanceKeypoints.reduce((result, { position, score }, keypointId) => {
if (!withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId)) result += score;
@ -19,83 +18,22 @@ function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
}, 0.0);
return notOverlappedKeypointScores / instanceKeypoints.length;
}
// A point (y, x) is considered as root part candidate if its score is a
// maximum in a window |y - y'| <= kLocalMaximumRadius, |x - x'| <=
// kLocalMaximumRadius.
const kLocalMaximumRadius = 1;
/**
* Detects multiple poses and finds their parts from part scores and
* displacement vectors. It returns up to `maxDetections` object instance
* detections in decreasing root score order. It works as follows: We first
* create a priority queue with local part score maxima above
* `scoreThreshold`, considering all parts at the same time. Then we
* iteratively pull the top element of the queue (in decreasing score order)
* and treat it as a root candidate for a new object instance. To avoid
* duplicate detections, we reject the root candidate if it is within a disk
* of `nmsRadius` pixels from the corresponding part of a previously detected
* instance, which is a form of part-based non-maximum suppression (NMS). If
* the root candidate passes the NMS check, we start a new object instance
* detection, treating the corresponding part as root and finding the
* positions of the remaining parts by following the displacement vectors
* along the tree-structured part graph. We assign to the newly detected
* instance a score equal to the sum of scores of its parts which have not
* been claimed by a previous instance (i.e., those at least `nmsRadius`
* pixels away from the corresponding part of all previously detected
* instances), divided by the total number of parts `numParts`.
*
* @param heatmapScores 3-D tensor with shape `[height, width, numParts]`.
* The value of `heatmapScores[y, x, k]` is the score of placing the `k`-th
* object part at position `(y, x)`.
*
* @param offsets 3-D tensor with shape `[height, width, numParts * 2]`.
* The value of `[offsets[y, x, k], offsets[y, x, k + numParts]]` is the
* short range offset vector of the `k`-th object part at heatmap
* position `(y, x)`.
*
* @param displacementsFwd 3-D tensor of shape
* `[height, width, 2 * num_edges]`, where `num_edges = num_parts - 1` is the
* number of edges (parent-child pairs) in the tree. It contains the forward
* displacements between consecutive parts from the root towards the leaves.
*
* @param displacementsBwd 3-D tensor of shape
* `[height, width, 2 * num_edges]`, where `num_edges = num_parts - 1` is the
* number of edges (parent-child pairs) in the tree. It contains the backward
* displacements between consecutive parts from the root towards the leaves.
*
* @param outputStride The output stride that was used when feed-forwarding
* through the PoseNet model. Must be 32, 16, or 8.
*
* @param maxPoseDetections Maximum number of returned instance detections per
* image.
*
* @param scoreThreshold Only return instance detections that have root part
* score greater or equal to this value. Defaults to 0.5.
*
* @param nmsRadius Non-maximum suppression part distance. It needs to be
* strictly positive. Two parts suppress each other if they are less than
* `nmsRadius` pixels away. Defaults to 20.
*
* @return An array of poses and their scores, each containing keypoints and
* the corresponding keypoint scores.
*/
function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, outputStride, maxPoseDetections, scoreThreshold, nmsRadius) {
function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config) {
const poses = [];
const queue = buildParts.buildPartWithScoreQueue(scoreThreshold, kLocalMaximumRadius, scoresBuffer);
const squaredNmsRadius = nmsRadius * nmsRadius;
// Generate at most maxDetections object instances per image in
// decreasing root part score order.
while (poses.length < maxPoseDetections && !queue.empty()) {
const queue = buildParts.buildPartWithScoreQueue(config.body.scoreThreshold, kLocalMaximumRadius, scoresBuffer);
const squaredNmsRadius = config.body.nmsRadius ** 2; // exponentiation (** not bitwise ^) so the radius is actually squared
// Generate at most maxDetections object instances per image in decreasing root part score order.
while (poses.length < config.body.maxDetections && !queue.empty()) {
// The top element in the queue is the next root candidate.
const root = queue.dequeue();
// Part-based non-maximum suppression: We reject a root candidate if it
// is within a disk of `nmsRadius` pixels from the corresponding part of
// a previously detected instance.
const rootImageCoords = vectors.getImageCoords(root.part, outputStride, offsetsBuffer);
// Part-based non-maximum suppression: We reject a root candidate if it is within a disk of `nmsRadius` pixels from the corresponding part of a previously detected instance.
const rootImageCoords = vectors.getImageCoords(root.part, config.body.outputStride, offsetsBuffer);
if (withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootImageCoords, root.part.id)) continue;
// Start a new detection instance at the position of the root.
const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, outputStride, displacementsFwdBuffer, displacementsBwdBuffer);
// Else start a new detection instance at the position of the root.
const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, config.body.outputStride, displacementsFwdBuffer, displacementsBwdBuffer);
const score = getInstanceScore(poses, squaredNmsRadius, keypoints);
if (score > scoreThreshold) poses.push({ keypoints, score });
if (score > config.body.scoreThreshold) poses.push({ keypoints, score });
}
return poses;
}
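The visible change here is the signature: the tuning parameters (outputStride, maxDetections, scoreThreshold, nmsRadius) are no longer passed individually but read from config.body. A hedged call-site sketch, with the tensor buffers assumed to exist as in modelPoseNet.js and the numbers purely illustrative:

// before: decodeMultiplePoses(scores, offsets, dispFwd, dispBwd, outputStride, maxDetections, scoreThreshold, nmsRadius)
// after: everything comes from config.body
const poses = await decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, {
  body: { maxDetections: 5, scoreThreshold: 0.5, nmsRadius: 20, outputStride: 16 },
});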

View File

@ -1,5 +1,6 @@
import * as keypoints from './keypoints';
import * as vectors from './vectors';
import * as decoders from './decoders';
const parentChildrenTuples = keypoints.poseChain.map(([parentJoinName, childJoinName]) => ([keypoints.partIds[parentJoinName], keypoints.partIds[childJoinName]]));
const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId);
@ -17,13 +18,7 @@ function getStridedIndexNearPoint(point, outputStride, height, width) {
x: vectors.clamp(Math.round(point.x / outputStride), 0, width - 1),
};
}
/**
* We get a new keypoint along the `edgeId` for the pose instance, assuming
* that the position of the `idSource` part is already known. For this, we
* follow the displacement vector from the source to target part (stored in
* the `i`-t channel of the displacement tensor). The displaced keypoint
* vector is refined using the offset vector by `offsetRefineStep` times.
*/
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
const [height, width] = scoresBuffer.shape;
// Nearest neighbor interpolation for the source->target displacements.
@ -43,12 +38,7 @@ function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scor
const score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId);
return { position: targetKeypoint, part: keypoints.partNames[targetKeypointId], score };
}
/**
* Follows the displacement fields to decode the full pose of the object
* instance given the position of a part that acts as root.
*
* @return An array of decoded keypoints and their scores for a single pose
*/
function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
const numParts = scores.shape[2];
const numEdges = parentToChildEdges.length;
@ -80,3 +70,31 @@ function decodePose(root, scores, offsets, outputStride, displacementsFwd, displ
return instanceKeypoints;
}
exports.decodePose = decodePose;
async function decodeSinglePose(heatmapScores, offsets, config) {
let totalScore = 0.0;
const heatmapValues = decoders.argmax2d(heatmapScores);
const allTensorBuffers = await Promise.all([heatmapScores.buffer(), offsets.buffer(), heatmapValues.buffer()]);
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const heatmapValuesBuffer = allTensorBuffers[2];
const offsetPoints = decoders.getOffsetPoints(heatmapValuesBuffer, config.body.outputStride, offsetsBuffer);
const offsetPointsBuffer = await offsetPoints.buffer();
const keypointConfidence = Array.from(decoders.getPointsConfidence(scoresBuffer, heatmapValuesBuffer));
const instanceKeypoints = keypointConfidence.map((score, i) => {
totalScore += score;
return {
position: {
y: offsetPointsBuffer.get(i, 0),
x: offsetPointsBuffer.get(i, 1),
},
part: keypoints.partNames[i],
score,
};
});
const filteredKeypoints = instanceKeypoints.filter((kpt) => kpt.score > config.body.scoreThreshold);
heatmapValues.dispose();
offsetPoints.dispose();
return { keypoints: filteredKeypoints, score: totalScore / instanceKeypoints.length };
}
exports.decodeSinglePose = decodeSinglePose;
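decodeSinglePose is the new fast path picked by modelPoseNet.js further down when config.body.maxDetections < 2. A hedged usage sketch, assuming res holds the named outputs returned by BaseModel.predict():

// hypothetical call site for the single-pose decoder added above
const pose = await decodeSinglePose(res.heatmapScores, res.offsets, {
  body: { outputStride: 16, scoreThreshold: 0.5 },
});
// pose.keypoints contains only parts whose score exceeds scoreThreshold;
// pose.score is the mean confidence over all keypoints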

View File

@ -12,12 +12,14 @@ function getPointsConfidence(heatmapScores, heatMapCoords) {
return result;
}
exports.getPointsConfidence = getPointsConfidence;
function getOffsetPoint(y, x, keypoint, offsetsBuffer) {
return {
y: offsetsBuffer.get(y, x, keypoint),
x: offsetsBuffer.get(y, x, keypoint + kpt.NUM_KEYPOINTS),
};
}
function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) {
const result = [];
for (let keypoint = 0; keypoint < kpt.NUM_KEYPOINTS; keypoint++) {
@ -30,14 +32,9 @@ function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) {
return tf.tensor2d(result, [kpt.NUM_KEYPOINTS, 2]);
}
exports.getOffsetVectors = getOffsetVectors;
function getOffsetPoints(heatMapCoordsBuffer, outputStride, offsetsBuffer) {
return tf.tidy(() => {
const offsetVectors = getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer);
return heatMapCoordsBuffer.toTensor()
.mul(tf.scalar(outputStride, 'int32'))
.toFloat()
.add(offsetVectors);
});
return tf.tidy(() => heatMapCoordsBuffer.toTensor().mul(tf.scalar(outputStride, 'int32')).toFloat().add(getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer)));
}
exports.getOffsetPoints = getOffsetPoints;
@ -47,6 +44,7 @@ function mod(a, b) {
return a.sub(floored.mul(tf.scalar(b, 'int32')));
});
}
function argmax2d(inputs) {
const [height, width, depth] = inputs.shape;
return tf.tidy(() => {

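The one-liner version of getOffsetPoints keeps the original arithmetic: image coordinate = heatmap index × outputStride + learned offset. A worked example with made-up numbers:

// illustrative arithmetic only: outputStride = 16, heatmap argmax at (y = 7, x = 12), offsets (5.2, -3.1)
// y = 7 * 16 + 5.2 = 117.2
// x = 12 * 16 - 3.1 = 188.9   -> keypoint position in input-image pixels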
View File

@ -3,11 +3,14 @@ exports.partNames = [
'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist',
'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle',
];
exports.NUM_KEYPOINTS = exports.partNames.length;
exports.partIds = exports.partNames.reduce((result, jointName, i) => {
result[jointName] = i;
return result;
}, {});
const connectedPartNames = [
['leftHip', 'leftShoulder'], ['leftElbow', 'leftShoulder'],
['leftElbow', 'leftWrist'], ['leftHip', 'leftKnee'],
@ -16,12 +19,8 @@ const connectedPartNames = [
['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'],
['leftShoulder', 'rightShoulder'], ['leftHip', 'rightHip'],
];
/*
* Define the skeleton. This defines the parent->child relationships of our
* tree. Arbitrarily this defines the nose as the root of the tree, however
* since we will infer the displacement for both parent->child and
* child->parent, we can define the tree root as any node.
*/
exports.connectedPartIndices = connectedPartNames.map(([jointNameA, jointNameB]) => ([exports.partIds[jointNameA], exports.partIds[jointNameB]]));
exports.poseChain = [
['nose', 'leftEye'], ['leftEye', 'leftEar'], ['nose', 'rightEye'],
['rightEye', 'rightEar'], ['nose', 'leftShoulder'],
@ -32,7 +31,7 @@ exports.poseChain = [
['rightShoulder', 'rightHip'], ['rightHip', 'rightKnee'],
['rightKnee', 'rightAnkle'],
];
exports.connectedPartIndices = connectedPartNames.map(([jointNameA, jointNameB]) => ([exports.partIds[jointNameA], exports.partIds[jointNameB]]));
exports.partChannels = [
'left_face',
'right_face',

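For reference, the reduce above simply maps each part name to its index in partNames (the full 17-entry list begins with 'nose'; only its tail is visible in this hunk), and connectedPartIndices turns each name pair into an index pair:

// illustrative result of the partIds reduce, assuming the standard 17-part PoseNet order
// partIds = { nose: 0, leftEye: 1, ..., leftShoulder: 5, rightShoulder: 6, ..., rightAnkle: 16 }
// so ['leftShoulder', 'rightShoulder'] in connectedPartNames becomes [5, 6] in connectedPartIndices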
View File

@ -1,18 +1,29 @@
import * as tf from '../../dist/tfjs.esm.js';
const imageNetMean = [-123.15, -115.90, -103.06];
function nameOutputResultsMobileNet(results) {
const [offsets, heatmap, displacementFwd, displacementBwd] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
}
function nameOutputResultsResNet(results) {
const [displacementFwd, displacementBwd, offsets, heatmap] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
}
class BaseModel {
constructor(model, outputStride) {
constructor(model) {
this.model = model;
this.outputStride = outputStride;
}
predict(input) {
predict(input, config) {
return tf.tidy(() => {
const asFloat = this.preprocessInput(input.toFloat());
const asFloat = (config.body.modelType === 'ResNet') ? input.toFloat().add(imageNetMean) : input.toFloat().div(127.5).sub(1.0);
const asBatch = asFloat.expandDims(0);
const results = this.model.predict(asBatch);
const results3d = results.map((y) => y.squeeze([0]));
const namedResults = this.nameOutputResults(results3d);
const namedResults = (config.body.modelType === 'ResNet') ? nameOutputResultsResNet(results3d) : nameOutputResultsMobileNet(results3d);
return {
heatmapScores: namedResults.heatmap.sigmoid(),
offsets: namedResults.offsets,
@ -22,9 +33,6 @@ class BaseModel {
});
}
/**
* Releases the CPU and GPU memory allocated by the model.
*/
dispose() {
this.model.dispose();
}
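With modelMobileNet.js removed (see below), BaseModel now selects normalization by config.body.modelType instead of by subclass. A summary of the two preprocessing paths, copied from the predict() body above:

// MobileNet path: scale [0, 255] pixels into [-1, 1]
//   input.toFloat().div(127.5).sub(1.0)
// ResNet path: subtract the ImageNet channel means (imageNetMean holds negative values, so add() subtracts)
//   input.toFloat().add([-123.15, -115.90, -103.06])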

View File

@ -1,17 +0,0 @@
import * as tf from '../../dist/tfjs.esm.js';
import * as modelBase from './modelBase';
class MobileNet extends modelBase.BaseModel {
// eslint-disable-next-line class-methods-use-this
preprocessInput(input) {
// Normalize the pixels [0, 255] to be between [-1, 1].
return tf.tidy(() => tf.div(input, 127.5).sub(1.0));
}
// eslint-disable-next-line class-methods-use-this
nameOutputResults(results) {
const [offsets, heatmap, displacementFwd, displacementBwd] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
}
}
exports.MobileNet = MobileNet;

View File

@ -1,35 +1,54 @@
import { log } from '../log.js';
import * as tf from '../../dist/tfjs.esm.js';
import * as modelMobileNet from './modelMobileNet';
import * as modelBase from './modelBase';
import * as decodeMultiple from './decodeMultiple';
import * as decodePose from './decodePose';
import * as util from './util';
async function estimateMultiple(input, res, config) {
return new Promise(async (resolve) => {
const height = input.shape[1];
const width = input.shape[2];
const allTensorBuffers = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]);
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const displacementsFwdBuffer = allTensorBuffers[2];
const displacementsBwdBuffer = allTensorBuffers[3];
const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config);
const scaled = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
resolve(scaled);
});
}
async function estimateSingle(input, res, config) {
return new Promise(async (resolve) => {
const height = input.shape[1];
const width = input.shape[2];
const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config);
const poses = [pose];
const scaled = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
resolve(scaled);
});
}
class PoseNet {
constructor(net) {
this.baseModel = net;
this.outputStride = 16;
constructor(model) {
this.baseModel = model;
}
async estimatePoses(input, config) {
return new Promise(async (resolve) => {
const height = input.shape[1];
const width = input.shape[2];
const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
const res = this.baseModel.predict(resized);
const allTensorBuffers = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]);
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const displacementsFwdBuffer = allTensorBuffers[2];
const displacementsBwdBuffer = allTensorBuffers[3];
const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, this.outputStride, config.body.maxDetections, config.body.scoreThreshold, config.body.nmsRadius);
const resultPoses = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
res.heatmapScores.dispose();
res.offsets.dispose();
res.displacementFwd.dispose();
res.displacementBwd.dispose();
resized.dispose();
resolve(resultPoses);
});
const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
const res = this.baseModel.predict(resized, config);
const poses = (config.body.maxDetections < 2) ? await estimateSingle(input, res, config) : await estimateMultiple(input, res, config);
res.heatmapScores.dispose();
res.offsets.dispose();
res.displacementFwd.dispose();
res.displacementBwd.dispose();
resized.dispose();
return poses;
}
dispose() {
@ -39,8 +58,8 @@ class PoseNet {
exports.PoseNet = PoseNet;
async function load(config) {
const graphModel = await tf.loadGraphModel(config.body.modelPath);
const mobilenet = new modelMobileNet.MobileNet(graphModel, this.outputStride);
const model = await tf.loadGraphModel(config.body.modelPath);
const mobilenet = new modelBase.BaseModel(model);
log(`load model: ${config.body.modelPath.match(/\/(.*)\./)[1]}`);
return new PoseNet(mobilenet);
}
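Putting the refactor together: load() builds a single BaseModel regardless of model type, and estimatePoses() picks the single- or multi-pose decoder from config.body.maxDetections. A hedged end-to-end sketch; the model path, inputSize, and the 4-D input tensor are assumptions, not taken from this commit:

// hypothetical usage of the refactored PoseNet wrapper
const config = {
  body: {
    enabled: true,
    modelPath: '../models/posenet.json', // assumed location of the graph model
    inputSize: 257,                      // assumed input resolution
    modelType: 'MobileNet',
    outputStride: 16,
    maxDetections: 1,                    // < 2 selects the decodeSinglePose path
    scoreThreshold: 0.5,
    nmsRadius: 20,
  },
};
const posenet = await load(config);                        // loads the graph model and wraps it in BaseModel
const poses = await posenet.estimatePoses(tensor, config); // tensor: assumed [1, height, width, 3] input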

View File

@ -1,14 +1,10 @@
import * as modelMobileNet from './modelMobileNet';
import * as modelPoseNet from './modelPoseNet';
import * as decodeMultiple from './decodeMultiple';
import * as keypoints from './keypoints';
import * as util from './util';
exports.load = modelPoseNet.load;
exports.PoseNet = modelPoseNet.PoseNet;
exports.MobileNet = modelMobileNet.MobileNet;
exports.decodeMultiplePoses = decodeMultiple.decodeMultiplePoses;
exports.partChannels = keypoints.partChannels;
exports.partIds = keypoints.partIds;
exports.partNames = keypoints.partNames;