major work on body module

pull/50/head
Vladimir Mandic 2020-12-16 18:36:24 -05:00
parent bbebaccd3b
commit 1a73761da0
24 changed files with 422 additions and 450 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
node_modules node_modules
alternative

View File

@ -138,6 +138,9 @@ export default {
scoreThreshold: 0.5, // threshold for deciding when to remove boxes based on score scoreThreshold: 0.5, // threshold for deciding when to remove boxes based on score
// in non-maximum suppression // in non-maximum suppression
nmsRadius: 20, // radius for deciding points are too close in non-maximum suppression nmsRadius: 20, // radius for deciding points are too close in non-maximum suppression
outputStride: 16, // size of block in which to run point detection, smaller value means higher resolution
// defined by model itself, can be 8, 16, or 32
modelType: 'MobileNet', // Human includes MobileNet version, but you can switch to ResNet
}, },
hand: { hand: {

View File

@ -4,6 +4,13 @@ import Menu from './menu.js';
import GLBench from './gl-bench.js'; import GLBench from './gl-bench.js';
const userConfig = {}; // add any user configuration overrides const userConfig = {}; // add any user configuration overrides
/*
const userConfig = {
face: { enabled: false },
body: { enabled: true },
hand: { enabled: false },
};
*/
const human = new Human(userConfig); const human = new Human(userConfig);

View File

@ -136,49 +136,57 @@ async function drawBody(result, canvas, ui) {
} }
if (ui.drawPolygons) { if (ui.drawPolygons) {
const path = new Path2D(); const path = new Path2D();
let root;
let part; let part;
// torso // torso
part = result[i].keypoints.find((a) => a.part === 'leftShoulder'); root = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.moveTo(part.position.x, part.position.y); if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightShoulder'); part = result[i].keypoints.find((a) => a.part === 'rightShoulder');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightHip'); part = result[i].keypoints.find((a) => a.part === 'rightHip');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftHip'); part = result[i].keypoints.find((a) => a.part === 'leftHip');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftShoulder'); part = result[i].keypoints.find((a) => a.part === 'leftShoulder');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
// legs }
part = result[i].keypoints.find((a) => a.part === 'leftHip'); // leg left
path.moveTo(part.position.x, part.position.y); root = result[i].keypoints.find((a) => a.part === 'leftHip');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftKnee'); part = result[i].keypoints.find((a) => a.part === 'leftKnee');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftAnkle'); part = result[i].keypoints.find((a) => a.part === 'leftAnkle');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightHip'); }
path.moveTo(part.position.x, part.position.y); // leg right
root = result[i].keypoints.find((a) => a.part === 'rightHip');
if (root) {
path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightKnee'); part = result[i].keypoints.find((a) => a.part === 'rightKnee');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightAnkle'); part = result[i].keypoints.find((a) => a.part === 'rightAnkle');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
// arms }
part = result[i].keypoints.find((a) => a.part === 'rightShoulder'); // arm left
path.moveTo(part.position.x, part.position.y); root = result[i].keypoints.find((a) => a.part === 'leftShoulder');
part = result[i].keypoints.find((a) => a.part === 'leftShoulder'); if (root) {
path.lineTo(part.position.x, part.position.y); path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftElbow'); part = result[i].keypoints.find((a) => a.part === 'leftElbow');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'leftWrist'); part = result[i].keypoints.find((a) => a.part === 'leftWrist');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
// arms }
part = result[i].keypoints.find((a) => a.part === 'leftShoulder'); // arm right
path.moveTo(part.position.x, part.position.y); root = result[i].keypoints.find((a) => a.part === 'rightShoulder');
part = result[i].keypoints.find((a) => a.part === 'rightShoulder'); if (root) {
path.lineTo(part.position.x, part.position.y); path.moveTo(root.position.x, root.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightElbow'); part = result[i].keypoints.find((a) => a.part === 'rightElbow');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
part = result[i].keypoints.find((a) => a.part === 'rightWrist'); part = result[i].keypoints.find((a) => a.part === 'rightWrist');
path.lineTo(part.position.x, part.position.y); if (part) path.lineTo(part.position.x, part.position.y);
}
// draw all // draw all
ctx.stroke(path); ctx.stroke(path);
} }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,11 +1,11 @@
{ {
"inputs": { "inputs": {
"dist/human.esm.js": { "dist/human.esm.js": {
"bytes": 1836264, "bytes": 1839401,
"imports": [] "imports": []
}, },
"demo/draw.js": { "demo/draw.js": {
"bytes": 10630, "bytes": 10733,
"imports": [] "imports": []
}, },
"demo/menu.js": { "demo/menu.js": {
@ -17,7 +17,7 @@
"imports": [] "imports": []
}, },
"demo/browser.js": { "demo/browser.js": {
"bytes": 25337, "bytes": 25450,
"imports": [ "imports": [
{ {
"path": "dist/human.esm.js" "path": "dist/human.esm.js"
@ -38,17 +38,17 @@
"dist/demo-browser-index.js.map": { "dist/demo-browser-index.js.map": {
"imports": [], "imports": [],
"inputs": {}, "inputs": {},
"bytes": 2198146 "bytes": 2199701
}, },
"dist/demo-browser-index.js": { "dist/demo-browser-index.js": {
"imports": [], "imports": [],
"exports": [], "exports": [],
"inputs": { "inputs": {
"dist/human.esm.js": { "dist/human.esm.js": {
"bytesInOutput": 1829024 "bytesInOutput": 1832161
}, },
"demo/draw.js": { "demo/draw.js": {
"bytesInOutput": 7816 "bytesInOutput": 7726
}, },
"demo/menu.js": { "demo/menu.js": {
"bytesInOutput": 11800 "bytesInOutput": 11800
@ -60,7 +60,7 @@
"bytesInOutput": 19539 "bytesInOutput": 19539
} }
}, },
"bytes": 1882950 "bytes": 1885997
} }
} }
} }

4
dist/human.esm.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

74
dist/human.esm.json vendored
View File

@ -148,30 +148,19 @@
] ]
}, },
"src/body/modelBase.js": { "src/body/modelBase.js": {
"bytes": 889, "bytes": 1343,
"imports": [ "imports": [
{ {
"path": "dist/tfjs.esm.js" "path": "dist/tfjs.esm.js"
} }
] ]
}, },
"src/body/modelMobileNet.js": {
"bytes": 599,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelBase.js"
}
]
},
"src/body/heapSort.js": { "src/body/heapSort.js": {
"bytes": 1590, "bytes": 1590,
"imports": [] "imports": []
}, },
"src/body/buildParts.js": { "src/body/buildParts.js": {
"bytes": 2035, "bytes": 1775,
"imports": [ "imports": [
{ {
"path": "src/body/heapSort.js" "path": "src/body/heapSort.js"
@ -179,7 +168,7 @@
] ]
}, },
"src/body/keypoints.js": { "src/body/keypoints.js": {
"bytes": 2291, "bytes": 2011,
"imports": [] "imports": []
}, },
"src/body/vectors.js": { "src/body/vectors.js": {
@ -190,19 +179,33 @@
} }
] ]
}, },
"src/body/decoders.js": {
"bytes": 2083,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/keypoints.js"
}
]
},
"src/body/decodePose.js": { "src/body/decodePose.js": {
"bytes": 4530, "bytes": 5216,
"imports": [ "imports": [
{ {
"path": "src/body/keypoints.js" "path": "src/body/keypoints.js"
}, },
{ {
"path": "src/body/vectors.js" "path": "src/body/vectors.js"
},
{
"path": "src/body/decoders.js"
} }
] ]
}, },
"src/body/decodeMultiple.js": { "src/body/decodeMultiple.js": {
"bytes": 5608, "bytes": 2303,
"imports": [ "imports": [
{ {
"path": "src/body/buildParts.js" "path": "src/body/buildParts.js"
@ -224,7 +227,7 @@
] ]
}, },
"src/body/modelPoseNet.js": { "src/body/modelPoseNet.js": {
"bytes": 1905, "bytes": 2395,
"imports": [ "imports": [
{ {
"path": "src/log.js" "path": "src/log.js"
@ -233,28 +236,25 @@
"path": "dist/tfjs.esm.js" "path": "dist/tfjs.esm.js"
}, },
{ {
"path": "src/body/modelMobileNet.js" "path": "src/body/modelBase.js"
}, },
{ {
"path": "src/body/decodeMultiple.js" "path": "src/body/decodeMultiple.js"
}, },
{
"path": "src/body/decodePose.js"
},
{ {
"path": "src/body/util.js" "path": "src/body/util.js"
} }
] ]
}, },
"src/body/posenet.js": { "src/body/posenet.js": {
"bytes": 830, "bytes": 614,
"imports": [ "imports": [
{
"path": "src/body/modelMobileNet.js"
},
{ {
"path": "src/body/modelPoseNet.js" "path": "src/body/modelPoseNet.js"
}, },
{
"path": "src/body/decodeMultiple.js"
},
{ {
"path": "src/body/keypoints.js" "path": "src/body/keypoints.js"
}, },
@ -350,7 +350,7 @@
] ]
}, },
"config.js": { "config.js": {
"bytes": 9241, "bytes": 9530,
"imports": [] "imports": []
}, },
"src/sample.js": { "src/sample.js": {
@ -419,7 +419,7 @@
"dist/human.esm.js.map": { "dist/human.esm.js.map": {
"imports": [], "imports": [],
"inputs": {}, "inputs": {},
"bytes": 2104843 "bytes": 2106203
}, },
"dist/human.esm.js": { "dist/human.esm.js": {
"imports": [], "imports": [],
@ -461,10 +461,7 @@
"bytesInOutput": 1318 "bytesInOutput": 1318
}, },
"src/body/modelBase.js": { "src/body/modelBase.js": {
"bytesInOutput": 615 "bytesInOutput": 1080
},
"src/body/modelMobileNet.js": {
"bytesInOutput": 375
}, },
"src/body/heapSort.js": { "src/body/heapSort.js": {
"bytesInOutput": 1139 "bytesInOutput": 1139
@ -478,20 +475,23 @@
"src/body/vectors.js": { "src/body/vectors.js": {
"bytesInOutput": 1050 "bytesInOutput": 1050
}, },
"src/body/decoders.js": {
"bytesInOutput": 1722
},
"src/body/decodePose.js": { "src/body/decodePose.js": {
"bytesInOutput": 3111 "bytesInOutput": 4161
}, },
"src/body/decodeMultiple.js": { "src/body/decodeMultiple.js": {
"bytesInOutput": 1684 "bytesInOutput": 1698
}, },
"src/body/util.js": { "src/body/util.js": {
"bytesInOutput": 1913 "bytesInOutput": 1913
}, },
"src/body/modelPoseNet.js": { "src/body/modelPoseNet.js": {
"bytesInOutput": 1569 "bytesInOutput": 2002
}, },
"src/body/posenet.js": { "src/body/posenet.js": {
"bytesInOutput": 832 "bytesInOutput": 622
}, },
"src/hand/handdetector.js": { "src/hand/handdetector.js": {
"bytesInOutput": 2742 "bytesInOutput": 2742
@ -533,7 +533,7 @@
"bytesInOutput": 1796 "bytesInOutput": 1796
}, },
"config.js": { "config.js": {
"bytesInOutput": 1454 "bytesInOutput": 1492
}, },
"src/sample.js": { "src/sample.js": {
"bytesInOutput": 55295 "bytesInOutput": 55295
@ -542,7 +542,7 @@
"bytesInOutput": 21 "bytesInOutput": 21
} }
}, },
"bytes": 1836264 "bytes": 1839401
} }
} }
} }

4
dist/human.js vendored

File diff suppressed because one or more lines are too long

6
dist/human.js.map vendored

File diff suppressed because one or more lines are too long

74
dist/human.json vendored
View File

@ -148,30 +148,19 @@
] ]
}, },
"src/body/modelBase.js": { "src/body/modelBase.js": {
"bytes": 889, "bytes": 1343,
"imports": [ "imports": [
{ {
"path": "dist/tfjs.esm.js" "path": "dist/tfjs.esm.js"
} }
] ]
}, },
"src/body/modelMobileNet.js": {
"bytes": 599,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/modelBase.js"
}
]
},
"src/body/heapSort.js": { "src/body/heapSort.js": {
"bytes": 1590, "bytes": 1590,
"imports": [] "imports": []
}, },
"src/body/buildParts.js": { "src/body/buildParts.js": {
"bytes": 2035, "bytes": 1775,
"imports": [ "imports": [
{ {
"path": "src/body/heapSort.js" "path": "src/body/heapSort.js"
@ -179,7 +168,7 @@
] ]
}, },
"src/body/keypoints.js": { "src/body/keypoints.js": {
"bytes": 2291, "bytes": 2011,
"imports": [] "imports": []
}, },
"src/body/vectors.js": { "src/body/vectors.js": {
@ -190,19 +179,33 @@
} }
] ]
}, },
"src/body/decoders.js": {
"bytes": 2083,
"imports": [
{
"path": "dist/tfjs.esm.js"
},
{
"path": "src/body/keypoints.js"
}
]
},
"src/body/decodePose.js": { "src/body/decodePose.js": {
"bytes": 4530, "bytes": 5216,
"imports": [ "imports": [
{ {
"path": "src/body/keypoints.js" "path": "src/body/keypoints.js"
}, },
{ {
"path": "src/body/vectors.js" "path": "src/body/vectors.js"
},
{
"path": "src/body/decoders.js"
} }
] ]
}, },
"src/body/decodeMultiple.js": { "src/body/decodeMultiple.js": {
"bytes": 5608, "bytes": 2303,
"imports": [ "imports": [
{ {
"path": "src/body/buildParts.js" "path": "src/body/buildParts.js"
@ -224,7 +227,7 @@
] ]
}, },
"src/body/modelPoseNet.js": { "src/body/modelPoseNet.js": {
"bytes": 1905, "bytes": 2395,
"imports": [ "imports": [
{ {
"path": "src/log.js" "path": "src/log.js"
@ -233,28 +236,25 @@
"path": "dist/tfjs.esm.js" "path": "dist/tfjs.esm.js"
}, },
{ {
"path": "src/body/modelMobileNet.js" "path": "src/body/modelBase.js"
}, },
{ {
"path": "src/body/decodeMultiple.js" "path": "src/body/decodeMultiple.js"
}, },
{
"path": "src/body/decodePose.js"
},
{ {
"path": "src/body/util.js" "path": "src/body/util.js"
} }
] ]
}, },
"src/body/posenet.js": { "src/body/posenet.js": {
"bytes": 830, "bytes": 614,
"imports": [ "imports": [
{
"path": "src/body/modelMobileNet.js"
},
{ {
"path": "src/body/modelPoseNet.js" "path": "src/body/modelPoseNet.js"
}, },
{
"path": "src/body/decodeMultiple.js"
},
{ {
"path": "src/body/keypoints.js" "path": "src/body/keypoints.js"
}, },
@ -350,7 +350,7 @@
] ]
}, },
"config.js": { "config.js": {
"bytes": 9241, "bytes": 9530,
"imports": [] "imports": []
}, },
"src/sample.js": { "src/sample.js": {
@ -419,7 +419,7 @@
"dist/human.js.map": { "dist/human.js.map": {
"imports": [], "imports": [],
"inputs": {}, "inputs": {},
"bytes": 2121966 "bytes": 2123326
}, },
"dist/human.js": { "dist/human.js": {
"imports": [], "imports": [],
@ -459,10 +459,7 @@
"bytesInOutput": 1318 "bytesInOutput": 1318
}, },
"src/body/modelBase.js": { "src/body/modelBase.js": {
"bytesInOutput": 615 "bytesInOutput": 1080
},
"src/body/modelMobileNet.js": {
"bytesInOutput": 375
}, },
"src/body/heapSort.js": { "src/body/heapSort.js": {
"bytesInOutput": 1139 "bytesInOutput": 1139
@ -476,20 +473,23 @@
"src/body/vectors.js": { "src/body/vectors.js": {
"bytesInOutput": 1050 "bytesInOutput": 1050
}, },
"src/body/decoders.js": {
"bytesInOutput": 1722
},
"src/body/decodePose.js": { "src/body/decodePose.js": {
"bytesInOutput": 3111 "bytesInOutput": 4161
}, },
"src/body/decodeMultiple.js": { "src/body/decodeMultiple.js": {
"bytesInOutput": 1684 "bytesInOutput": 1698
}, },
"src/body/util.js": { "src/body/util.js": {
"bytesInOutput": 1913 "bytesInOutput": 1913
}, },
"src/body/modelPoseNet.js": { "src/body/modelPoseNet.js": {
"bytesInOutput": 1569 "bytesInOutput": 2002
}, },
"src/body/posenet.js": { "src/body/posenet.js": {
"bytesInOutput": 832 "bytesInOutput": 622
}, },
"src/hand/handdetector.js": { "src/hand/handdetector.js": {
"bytesInOutput": 2742 "bytesInOutput": 2742
@ -531,7 +531,7 @@
"bytesInOutput": 1796 "bytesInOutput": 1796
}, },
"config.js": { "config.js": {
"bytesInOutput": 1454 "bytesInOutput": 1492
}, },
"src/sample.js": { "src/sample.js": {
"bytesInOutput": 55295 "bytesInOutput": 55295
@ -540,7 +540,7 @@
"bytesInOutput": 21 "bytesInOutput": 21
} }
}, },
"bytes": 1836338 "bytes": 1839475
} }
} }
} }

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -14,17 +14,11 @@ function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, loca
break; break;
} }
} }
if (!localMaximum) { if (!localMaximum) break;
break;
}
} }
return localMaximum; return localMaximum;
} }
/**
* Builds a priority queue with part candidate positions for a specific image in
* the batch. For this we find all local maxima in the score maps with score
* values above a threshold. We create a single priority queue across all parts.
*/
function buildPartWithScoreQueue(scoreThreshold, localMaximumRadius, scores) { function buildPartWithScoreQueue(scoreThreshold, localMaximumRadius, scores) {
const [height, width, numKeypoints] = scores.shape; const [height, width, numKeypoints] = scores.shape;
const queue = new heapSort.MaxHeap(height * width * numKeypoints, ({ score }) => score); const queue = new heapSort.MaxHeap(height * width * numKeypoints, ({ score }) => score);

View File

@ -2,16 +2,15 @@ import * as buildParts from './buildParts';
import * as decodePose from './decodePose'; import * as decodePose from './decodePose';
import * as vectors from './vectors'; import * as vectors from './vectors';
const kLocalMaximumRadius = 1;
function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, { x, y }, keypointId) { function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, { x, y }, keypointId) {
return poses.some(({ keypoints }) => { return poses.some(({ keypoints }) => {
const correspondingKeypoint = keypoints[keypointId].position; const correspondingKeypoint = keypoints[keypointId].position;
return vectors.squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius; return vectors.squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius;
}); });
} }
/* Score the newly proposed object instance without taking into account
* the scores of the parts that overlap with any previously detected
* instance.
*/
function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) { function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
const notOverlappedKeypointScores = instanceKeypoints.reduce((result, { position, score }, keypointId) => { const notOverlappedKeypointScores = instanceKeypoints.reduce((result, { position, score }, keypointId) => {
if (!withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId)) result += score; if (!withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId)) result += score;
@ -19,83 +18,22 @@ function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
}, 0.0); }, 0.0);
return notOverlappedKeypointScores / instanceKeypoints.length; return notOverlappedKeypointScores / instanceKeypoints.length;
} }
/**
 * Detects multiple poses and finds their parts from part scores and
 * displacement vectors. Returns up to `config.body.maxDetections` instances.
 *
 * How it works: a priority queue is built from local part-score maxima above
 * `config.body.scoreThreshold` (a point is a candidate when its score is the
 * maximum inside a window of `kLocalMaximumRadius` around it). Candidates are
 * then pulled from the queue one by one and treated as the root of a new
 * instance. To avoid duplicate detections, a root candidate is rejected when
 * it lies within `config.body.nmsRadius` pixels of the corresponding part of
 * an already accepted pose (part-based non-maximum suppression). Otherwise the
 * remaining parts are located by following the displacement vectors along the
 * part graph, and the instance is scored by `getInstanceScore`.
 *
 * @param scoresBuffer buffer of heatmap scores, indexed `[y, x, part]`.
 * @param offsetsBuffer buffer of short-range offset vectors per part.
 * @param displacementsFwdBuffer buffer of forward (parent->child) displacements.
 * @param displacementsBwdBuffer buffer of backward (child->parent) displacements.
 * @param config configuration object; reads `config.body.scoreThreshold`,
 *   `config.body.nmsRadius`, `config.body.maxDetections` and
 *   `config.body.outputStride`.
 * @return array of `{ keypoints, score }` poses whose instance score is
 *   strictly greater than `config.body.scoreThreshold`.
 */
function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config) {
  const { scoreThreshold, nmsRadius, maxDetections, outputStride } = config.body;
  const poses = [];
  const queue = buildParts.buildPartWithScoreQueue(scoreThreshold, kLocalMaximumRadius, scoresBuffer);
  // Squared radius is compared against squared distances. Note that `^` is
  // bitwise XOR in JavaScript (20 ^ 2 === 22), so multiply explicitly.
  const squaredNmsRadius = nmsRadius * nmsRadius;
  // Generate at most maxDetections object instances per image.
  while (poses.length < maxDetections && !queue.empty()) {
    // The top element in the queue is the next root candidate.
    const root = queue.dequeue();
    // Part-based non-maximum suppression: reject a root candidate that is within
    // `nmsRadius` pixels of the corresponding part of a previously detected instance.
    const rootImageCoords = vectors.getImageCoords(root.part, outputStride, offsetsBuffer);
    if (withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootImageCoords, root.part.id)) continue;
    // Else start a new detection instance at the position of the root.
    const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, outputStride, displacementsFwdBuffer, displacementsBwdBuffer);
    const score = getInstanceScore(poses, squaredNmsRadius, keypoints);
    if (score > scoreThreshold) poses.push({ keypoints, score });
  }
  return poses;
}

View File

@ -1,5 +1,6 @@
import * as keypoints from './keypoints'; import * as keypoints from './keypoints';
import * as vectors from './vectors'; import * as vectors from './vectors';
import * as decoders from './decoders';
const parentChildrenTuples = keypoints.poseChain.map(([parentJoinName, childJoinName]) => ([keypoints.partIds[parentJoinName], keypoints.partIds[childJoinName]])); const parentChildrenTuples = keypoints.poseChain.map(([parentJoinName, childJoinName]) => ([keypoints.partIds[parentJoinName], keypoints.partIds[childJoinName]]));
const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId); const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId);
@ -17,13 +18,7 @@ function getStridedIndexNearPoint(point, outputStride, height, width) {
x: vectors.clamp(Math.round(point.x / outputStride), 0, width - 1), x: vectors.clamp(Math.round(point.x / outputStride), 0, width - 1),
}; };
} }
/**
* We get a new keypoint along the `edgeId` for the pose instance, assuming
* that the position of the `idSource` part is already known. For this, we
* follow the displacement vector from the source to target part (stored in
* the `i`-t channel of the displacement tensor). The displaced keypoint
* vector is refined using the offset vector by `offsetRefineStep` times.
*/
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) { function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
const [height, width] = scoresBuffer.shape; const [height, width] = scoresBuffer.shape;
// Nearest neighbor interpolation for the source->target displacements. // Nearest neighbor interpolation for the source->target displacements.
@ -43,12 +38,7 @@ function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scor
const score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId); const score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId);
return { position: targetKeypoint, part: keypoints.partNames[targetKeypointId], score }; return { position: targetKeypoint, part: keypoints.partNames[targetKeypointId], score };
} }
/**
* Follows the displacement fields to decode the full pose of the object
* instance given the position of a part that acts as root.
*
* @return An array of decoded keypoints and their scores for a single pose
*/
function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) { function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
const numParts = scores.shape[2]; const numParts = scores.shape[2];
const numEdges = parentToChildEdges.length; const numEdges = parentToChildEdges.length;
@ -80,3 +70,31 @@ function decodePose(root, scores, offsets, outputStride, displacementsFwd, displ
return instanceKeypoints; return instanceKeypoints;
} }
exports.decodePose = decodePose; exports.decodePose = decodePose;
/**
 * Decodes a single pose from the heatmap scores and offsets.
 *
 * Each keypoint gets a position read from the offset points computed by
 * `decoders.getOffsetPoints` and a confidence read from
 * `decoders.getPointsConfidence`. Keypoints whose score is not strictly above
 * `config.body.scoreThreshold` are dropped from the result, but the returned
 * pose score is the mean confidence over ALL keypoints (before filtering).
 *
 * @param heatmapScores object exposing an async `buffer()` (presumably a tfjs tensor - confirm against caller).
 * @param offsets object exposing an async `buffer()` (presumably a tfjs tensor - confirm against caller).
 * @param config configuration object; reads `config.body.outputStride` and `config.body.scoreThreshold`.
 * @return promise resolving to `{ keypoints, score }`.
 */
async function decodeSinglePose(heatmapScores, offsets, config) {
  const heatmapValues = decoders.argmax2d(heatmapScores);
  let offsetPoints;
  try {
    // Download all three buffers in parallel.
    const [scoresBuffer, offsetsBuffer, heatmapValuesBuffer] = await Promise.all([
      heatmapScores.buffer(),
      offsets.buffer(),
      heatmapValues.buffer(),
    ]);
    offsetPoints = decoders.getOffsetPoints(heatmapValuesBuffer, config.body.outputStride, offsetsBuffer);
    const offsetPointsBuffer = await offsetPoints.buffer();
    const keypointConfidence = Array.from(decoders.getPointsConfidence(scoresBuffer, heatmapValuesBuffer));
    const instanceKeypoints = keypointConfidence.map((score, i) => ({
      position: {
        y: offsetPointsBuffer.get(i, 0),
        x: offsetPointsBuffer.get(i, 1),
      },
      part: keypoints.partNames[i],
      score,
    }));
    // Sum in keypoint order so the result matches an incremental accumulation.
    const totalScore = keypointConfidence.reduce((sum, score) => sum + score, 0.0);
    const filteredKeypoints = instanceKeypoints.filter((kpt) => kpt.score > config.body.scoreThreshold);
    return { keypoints: filteredKeypoints, score: totalScore / instanceKeypoints.length };
  } finally {
    // Release intermediates even when a buffer download or helper throws,
    // otherwise they would leak on every failed call.
    heatmapValues.dispose();
    if (offsetPoints) offsetPoints.dispose();
  }
}
exports.decodeSinglePose = decodeSinglePose;

View File

@ -12,12 +12,14 @@ function getPointsConfidence(heatmapScores, heatMapCoords) {
return result; return result;
} }
exports.getPointsConfidence = getPointsConfidence; exports.getPointsConfidence = getPointsConfidence;
function getOffsetPoint(y, x, keypoint, offsetsBuffer) { function getOffsetPoint(y, x, keypoint, offsetsBuffer) {
return { return {
y: offsetsBuffer.get(y, x, keypoint), y: offsetsBuffer.get(y, x, keypoint),
x: offsetsBuffer.get(y, x, keypoint + kpt.NUM_KEYPOINTS), x: offsetsBuffer.get(y, x, keypoint + kpt.NUM_KEYPOINTS),
}; };
} }
function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) { function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) {
const result = []; const result = [];
for (let keypoint = 0; keypoint < kpt.NUM_KEYPOINTS; keypoint++) { for (let keypoint = 0; keypoint < kpt.NUM_KEYPOINTS; keypoint++) {
@ -30,14 +32,9 @@ function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) {
return tf.tensor2d(result, [kpt.NUM_KEYPOINTS, 2]); return tf.tensor2d(result, [kpt.NUM_KEYPOINTS, 2]);
} }
exports.getOffsetVectors = getOffsetVectors; exports.getOffsetVectors = getOffsetVectors;
function getOffsetPoints(heatMapCoordsBuffer, outputStride, offsetsBuffer) { function getOffsetPoints(heatMapCoordsBuffer, outputStride, offsetsBuffer) {
return tf.tidy(() => { return tf.tidy(() => heatMapCoordsBuffer.toTensor().mul(tf.scalar(outputStride, 'int32')).toFloat().add(getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer)));
const offsetVectors = getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer);
return heatMapCoordsBuffer.toTensor()
.mul(tf.scalar(outputStride, 'int32'))
.toFloat()
.add(offsetVectors);
});
} }
exports.getOffsetPoints = getOffsetPoints; exports.getOffsetPoints = getOffsetPoints;
@ -47,6 +44,7 @@ function mod(a, b) {
return a.sub(floored.mul(tf.scalar(b, 'int32'))); return a.sub(floored.mul(tf.scalar(b, 'int32')));
}); });
} }
function argmax2d(inputs) { function argmax2d(inputs) {
const [height, width, depth] = inputs.shape; const [height, width, depth] = inputs.shape;
return tf.tidy(() => { return tf.tidy(() => {

View File

@ -3,11 +3,14 @@ exports.partNames = [
'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist', 'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist',
'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle', 'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle',
]; ];
exports.NUM_KEYPOINTS = exports.partNames.length; exports.NUM_KEYPOINTS = exports.partNames.length;
exports.partIds = exports.partNames.reduce((result, jointName, i) => { exports.partIds = exports.partNames.reduce((result, jointName, i) => {
result[jointName] = i; result[jointName] = i;
return result; return result;
}, {}); }, {});
const connectedPartNames = [ const connectedPartNames = [
['leftHip', 'leftShoulder'], ['leftElbow', 'leftShoulder'], ['leftHip', 'leftShoulder'], ['leftElbow', 'leftShoulder'],
['leftElbow', 'leftWrist'], ['leftHip', 'leftKnee'], ['leftElbow', 'leftWrist'], ['leftHip', 'leftKnee'],
@ -16,12 +19,8 @@ const connectedPartNames = [
['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'], ['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'],
['leftShoulder', 'rightShoulder'], ['leftHip', 'rightHip'], ['leftShoulder', 'rightShoulder'], ['leftHip', 'rightHip'],
]; ];
/* exports.connectedPartIndices = connectedPartNames.map(([jointNameA, jointNameB]) => ([exports.partIds[jointNameA], exports.partIds[jointNameB]]));
* Define the skeleton. This defines the parent->child relationships of our
* tree. Arbitrarily this defines the nose as the root of the tree, however
* since we will infer the displacement for both parent->child and
* child->parent, we can define the tree root as any node.
*/
exports.poseChain = [ exports.poseChain = [
['nose', 'leftEye'], ['leftEye', 'leftEar'], ['nose', 'rightEye'], ['nose', 'leftEye'], ['leftEye', 'leftEar'], ['nose', 'rightEye'],
['rightEye', 'rightEar'], ['nose', 'leftShoulder'], ['rightEye', 'rightEar'], ['nose', 'leftShoulder'],
@ -32,7 +31,7 @@ exports.poseChain = [
['rightShoulder', 'rightHip'], ['rightHip', 'rightKnee'], ['rightShoulder', 'rightHip'], ['rightHip', 'rightKnee'],
['rightKnee', 'rightAnkle'], ['rightKnee', 'rightAnkle'],
]; ];
exports.connectedPartIndices = connectedPartNames.map(([jointNameA, jointNameB]) => ([exports.partIds[jointNameA], exports.partIds[jointNameB]]));
exports.partChannels = [ exports.partChannels = [
'left_face', 'left_face',
'right_face', 'right_face',

View File

@ -1,18 +1,29 @@
import * as tf from '../../dist/tfjs.esm.js'; import * as tf from '../../dist/tfjs.esm.js';
class BaseModel { const imageNetMean = [-123.15, -115.90, -103.06];
constructor(model, outputStride) {
this.model = model; function nameOutputResultsMobileNet(results) {
this.outputStride = outputStride; const [offsets, heatmap, displacementFwd, displacementBwd] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
} }
predict(input) { function nameOutputResultsResNet(results) {
const [displacementFwd, displacementBwd, offsets, heatmap] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
}
class BaseModel {
constructor(model) {
this.model = model;
}
predict(input, config) {
return tf.tidy(() => { return tf.tidy(() => {
const asFloat = this.preprocessInput(input.toFloat()); const asFloat = (config.body.modelType === 'ResNet') ? input.toFloat().add(imageNetMean) : input.toFloat().div(127.5).sub(1.0);
const asBatch = asFloat.expandDims(0); const asBatch = asFloat.expandDims(0);
const results = this.model.predict(asBatch); const results = this.model.predict(asBatch);
const results3d = results.map((y) => y.squeeze([0])); const results3d = results.map((y) => y.squeeze([0]));
const namedResults = this.nameOutputResults(results3d); const namedResults = (config.body.modelType === 'ResNet') ? nameOutputResultsResNet(results3d) : nameOutputResultsMobileNet(results3d);
return { return {
heatmapScores: namedResults.heatmap.sigmoid(), heatmapScores: namedResults.heatmap.sigmoid(),
offsets: namedResults.offsets, offsets: namedResults.offsets,
@ -22,9 +33,6 @@ class BaseModel {
}); });
} }
/**
* Releases the CPU and GPU memory allocated by the model.
*/
dispose() { dispose() {
this.model.dispose(); this.model.dispose();
} }

View File

@ -1,17 +0,0 @@
import * as tf from '../../dist/tfjs.esm.js';
import * as modelBase from './modelBase';
class MobileNet extends modelBase.BaseModel {
// eslint-disable-next-line class-methods-use-this
preprocessInput(input) {
// Normalize the pixels [0, 255] to be between [-1, 1].
return tf.tidy(() => tf.div(input, 127.5).sub(1.0));
}
// eslint-disable-next-line class-methods-use-this
nameOutputResults(results) {
const [offsets, heatmap, displacementFwd, displacementBwd] = results;
return { offsets, heatmap, displacementFwd, displacementBwd };
}
}
exports.MobileNet = MobileNet;

View File

@ -1,35 +1,54 @@
import { log } from '../log.js'; import { log } from '../log.js';
import * as tf from '../../dist/tfjs.esm.js'; import * as tf from '../../dist/tfjs.esm.js';
import * as modelMobileNet from './modelMobileNet'; import * as modelBase from './modelBase';
import * as decodeMultiple from './decodeMultiple'; import * as decodeMultiple from './decodeMultiple';
import * as decodePose from './decodePose';
import * as util from './util'; import * as util from './util';
class PoseNet { async function estimateMultiple(input, res, config) {
constructor(net) {
this.baseModel = net;
this.outputStride = 16;
}
async estimatePoses(input, config) {
return new Promise(async (resolve) => { return new Promise(async (resolve) => {
const height = input.shape[1]; const height = input.shape[1];
const width = input.shape[2]; const width = input.shape[2];
const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
const res = this.baseModel.predict(resized);
const allTensorBuffers = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]); const allTensorBuffers = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]);
const scoresBuffer = allTensorBuffers[0]; const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1]; const offsetsBuffer = allTensorBuffers[1];
const displacementsFwdBuffer = allTensorBuffers[2]; const displacementsFwdBuffer = allTensorBuffers[2];
const displacementsBwdBuffer = allTensorBuffers[3]; const displacementsBwdBuffer = allTensorBuffers[3];
const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, this.outputStride, config.body.maxDetections, config.body.scoreThreshold, config.body.nmsRadius); const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config);
const resultPoses = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]); const scaled = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
resolve(scaled);
});
}
async function estimateSingle(input, res, config) {
return new Promise(async (resolve) => {
const height = input.shape[1];
const width = input.shape[2];
const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config);
const poses = [pose];
const scaled = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
resolve(scaled);
});
}
class PoseNet {
constructor(model) {
this.baseModel = model;
}
async estimatePoses(input, config) {
const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
const res = this.baseModel.predict(resized, config);
const poses = (config.body.maxDetections < 2) ? await estimateSingle(input, res, config) : await estimateMultiple(input, res, config);
res.heatmapScores.dispose(); res.heatmapScores.dispose();
res.offsets.dispose(); res.offsets.dispose();
res.displacementFwd.dispose(); res.displacementFwd.dispose();
res.displacementBwd.dispose(); res.displacementBwd.dispose();
resized.dispose(); resized.dispose();
resolve(resultPoses);
}); return poses;
} }
dispose() { dispose() {
@ -39,8 +58,8 @@ class PoseNet {
exports.PoseNet = PoseNet; exports.PoseNet = PoseNet;
async function load(config) { async function load(config) {
const graphModel = await tf.loadGraphModel(config.body.modelPath); const model = await tf.loadGraphModel(config.body.modelPath);
const mobilenet = new modelMobileNet.MobileNet(graphModel, this.outputStride); const mobilenet = new modelBase.BaseModel(model);
log(`load model: ${config.body.modelPath.match(/\/(.*)\./)[1]}`); log(`load model: ${config.body.modelPath.match(/\/(.*)\./)[1]}`);
return new PoseNet(mobilenet); return new PoseNet(mobilenet);
} }

View File

@ -1,14 +1,10 @@
import * as modelMobileNet from './modelMobileNet';
import * as modelPoseNet from './modelPoseNet'; import * as modelPoseNet from './modelPoseNet';
import * as decodeMultiple from './decodeMultiple';
import * as keypoints from './keypoints'; import * as keypoints from './keypoints';
import * as util from './util'; import * as util from './util';
exports.load = modelPoseNet.load; exports.load = modelPoseNet.load;
exports.PoseNet = modelPoseNet.PoseNet; exports.PoseNet = modelPoseNet.PoseNet;
exports.MobileNet = modelMobileNet.MobileNet;
exports.decodeMultiplePoses = decodeMultiple.decodeMultiplePoses;
exports.partChannels = keypoints.partChannels; exports.partChannels = keypoints.partChannels;
exports.partIds = keypoints.partIds; exports.partIds = keypoints.partIds;
exports.partNames = keypoints.partNames; exports.partNames = keypoints.partNames;