human/src/handpose/handpipeline.ts

157 lines
7.0 KiB
TypeScript
Raw Normal View History

2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-11-10 02:13:38 +01:00
import * as box from './box';
import * as util from './util';
2020-11-04 07:11:24 +01:00
2020-12-10 21:46:45 +01:00
// Disabled palm-box shift vector, kept for reference; the matching shiftBox call in
// getBoxForPalmLandmarks is commented out as well.
// const PALM_BOX_SHIFT_VECTOR = [0, -0.4];
// Factor handed to box.enlargeBox for the squarified box built around the rotated palm landmarks.
const PALM_BOX_ENLARGE_FACTOR = 5; // default 3
// Disabled hand-box shift vector, kept for reference; the matching shiftBox calls are commented out as well.
// const HAND_BOX_SHIFT_VECTOR = [0, -0.1]; // move detected hand box by x,y to ease landmark detection
// Factor handed to box.enlargeBox for the squarified box built around a full set of hand landmarks
// (and around the raw detector box when landmark detection is disabled).
const HAND_BOX_ENLARGE_FACTOR = 1.65; // default 1.65
2020-11-04 07:11:24 +01:00
// Indices into the full hand-landmark array whose (x, y) values are copied into a box's
// `palmLandmarks` list by getBoxForHandLandmarks, in this order.
const PALM_LANDMARK_IDS = [0, 5, 9, 13, 17, 1, 2];
// Positions inside a box's `palmLandmarks` list (not inside the full landmark array) of the two
// points passed to util.computeRotation; with PALM_LANDMARK_IDS above they map to landmarks 0 and 9.
const PALM_LANDMARKS_INDEX_OF_PALM_BASE = 0;
const PALM_LANDMARKS_INDEX_OF_MIDDLE_FINGER_BASE = 2;
2021-02-08 17:39:09 +01:00
/**
 * Hand pipeline that combines a hand (palm) detector with a landmark model.
 *
 * The detector supplies candidate boxes; each box is cropped out of the input image,
 * passed to the landmark model, and the resulting landmarks are used to derive the box
 * that is tracked on the following call. The working set of boxes is kept in
 * `storedBoxes` so the detector does not have to run on every frame.
 */
export class HandPipeline {
  /** Detector object; must expose `estimateHandBounds(image, config)`. */
  handDetector: any;
  /** Landmark model; must expose `predict(tensor)` and `inputs[0].shape`. */
  landmarkDetector: any;
  /** Side length used for the square crop fed to the landmark model. */
  inputSize: number;
  /** Working set of boxes carried over between calls to `estimateHands`. */
  storedBoxes: any;
  /** Number of calls since the detector last ran (only advanced when `config.videoOptimized`). */
  skipped: number;
  /** Number of results returned by the previous `estimateHands` call. */
  detectedHands: number;

  /**
   * @param handDetector - detector used to find candidate hand boxes
   * @param landmarkDetector - landmark model; its `inputs[0].shape[2]` is read as the input size
   *   (left `undefined` when no model is supplied)
   */
  constructor(handDetector, landmarkDetector) {
    this.handDetector = handDetector;
    this.landmarkDetector = landmarkDetector;
    this.inputSize = this.landmarkDetector?.inputs[0].shape[2];
    this.storedBoxes = [];
    this.skipped = 0;
    this.detectedHands = 0;
  }

  /**
   * Builds the crop box for a freshly detected palm.
   *
   * @param palmLandmarks - list of [x, y] palm points
   * @param rotationMatrix - matrix applied to every point through `util.rotatePoint`
   * @returns squarified bounding box of the rotated points, enlarged by `PALM_BOX_ENLARGE_FACTOR`
   */
  getBoxForPalmLandmarks(palmLandmarks, rotationMatrix) {
    // each point is extended with a trailing 1 before being passed to util.rotatePoint
    const rotatedPalmLandmarks = palmLandmarks.map((coord) => util.rotatePoint([...coord, 1], rotationMatrix));
    const boxAroundPalm = this.calculateLandmarksBoundingBox(rotatedPalmLandmarks);
    // return box.enlargeBox(box.squarifyBox(box.shiftBox(boxAroundPalm, PALM_BOX_SHIFT_VECTOR)), PALM_BOX_ENLARGE_FACTOR);
    return box.enlargeBox(box.squarifyBox(boxAroundPalm), PALM_BOX_ENLARGE_FACTOR);
  }

  /**
   * Builds the box that is tracked on the next call from a full set of hand landmarks.
   *
   * @param landmarks - list of landmark coordinates; must contain every index in `PALM_LANDMARK_IDS`
   * @returns squarified bounding box enlarged by `HAND_BOX_ENLARGE_FACTOR`, with a `palmLandmarks`
   *   property holding the [x, y] part of the landmarks listed in `PALM_LANDMARK_IDS`
   */
  getBoxForHandLandmarks(landmarks) {
    const boundingBox = this.calculateLandmarksBoundingBox(landmarks);
    // const boxAroundHand = box.enlargeBox(box.squarifyBox(box.shiftBox(boundingBox, HAND_BOX_SHIFT_VECTOR)), HAND_BOX_ENLARGE_FACTOR);
    const boxAroundHand = box.enlargeBox(box.squarifyBox(boundingBox), HAND_BOX_ENLARGE_FACTOR);
    boxAroundHand.palmLandmarks = [];
    for (let i = 0; i < PALM_LANDMARK_IDS.length; i++) {
      boxAroundHand.palmLandmarks.push(landmarks[PALM_LANDMARK_IDS[i]].slice(0, 2));
    }
    return boxAroundHand;
  }

  /**
   * Maps landmark coordinates produced by the landmark model back into the coordinate
   * space of the box they were cropped from.
   *
   * @param rawCoords - list of [x, y, z] values as returned by the landmark model
   * @param box2 - box the model input was cropped from
   * @param angle - rotation angle that was applied before cropping
   * @param rotationMatrix - rotation matrix that was used to build the crop box
   * @returns list of [x, y, z] coordinates
   */
  transformRawCoords(rawCoords, box2, angle, rotationMatrix) {
    const boxSize = box.getBoxSize(box2);
    // x and y are scaled per axis; z is scaled by the average of both axis factors
    const scaleFactor = [boxSize[0] / this.inputSize, boxSize[1] / this.inputSize, (boxSize[0] + boxSize[1]) / this.inputSize / 2];
    // x and y are re-centered around the middle of the model input before scaling
    const coordsScaled = rawCoords.map((coord) => [
      scaleFactor[0] * (coord[0] - this.inputSize / 2),
      scaleFactor[1] * (coord[1] - this.inputSize / 2),
      scaleFactor[2] * coord[2],
    ]);
    const coordsRotationMatrix = util.buildRotationMatrix(angle, [0, 0]);
    const coordsRotated = coordsScaled.map((coord) => {
      const rotated = util.rotatePoint(coord, coordsRotationMatrix);
      // z is carried over unchanged from the scaled coordinate
      return [...rotated, coord[2]];
    });
    const inverseRotationMatrix = util.invertTransformMatrix(rotationMatrix);
    const boxCenter = [...box.getBoxCenter(box2), 1];
    const originalBoxCenter = [
      util.dot(boxCenter, inverseRotationMatrix[0]),
      util.dot(boxCenter, inverseRotationMatrix[1]),
    ];
    return coordsRotated.map((coord) => [
      coord[0] + originalBoxCenter[0],
      coord[1] + originalBoxCenter[1],
      coord[2],
    ]);
  }

  /**
   * Crops one hand out of `image`, runs the landmark model on it and reads back the results.
   *
   * Every intermediate tensor is released in a `finally` block, so nothing is leaked when
   * cropping, prediction or data read-back throws.
   *
   * @param image - input tensor; it is not disposed here
   * @param config - configuration; reads `hand.rotation` and `hand.minConfidence`
   * @param angle - rotation angle used when `config.hand.rotation` is enabled
   * @param palmCenterNormalized - rotation center passed to `tf.image.rotateWithOffset`
   * @param cropBox - box to cut from the (optionally rotated) image
   * @returns the model confidence, plus the keypoints reshaped to rows of 3 values when the
   *   confidence is at least `config.hand.minConfidence` (otherwise `rawCoords` is `null`)
   */
  private async detectLandmarks(image, config, angle, palmCenterNormalized, cropBox) {
    let handImage;
    const rotatedImage = config.hand.rotation ? tf.image.rotateWithOffset(image, angle, 0, palmCenterNormalized) : image.clone();
    try {
      const croppedInput = box.cutBoxFromImageAndResize(cropBox, rotatedImage, [this.inputSize, this.inputSize]);
      try {
        handImage = croppedInput.div(255);
      } finally {
        croppedInput.dispose();
      }
    } finally {
      rotatedImage.dispose();
    }

    let confidenceT;
    let keypoints;
    try {
      [confidenceT, keypoints] = await this.landmarkDetector.predict(handImage);
    } finally {
      handImage.dispose();
    }

    try {
      const confidence = confidenceT.dataSync()[0];
      let rawCoords = null;
      if (confidence >= config.hand.minConfidence) {
        const keypointsReshaped = tf.reshape(keypoints, [-1, 3]);
        try {
          rawCoords = keypointsReshaped.arraySync();
        } finally {
          keypointsReshaped.dispose();
        }
      }
      return { confidence, rawCoords };
    } finally {
      // single disposal point for both model outputs
      confidenceT.dispose();
      keypoints.dispose();
    }
  }

  /**
   * Detects hands in `image`.
   *
   * The detector runs when no frame has been skipped yet, when more than
   * `config.hand.skipFrames` calls have passed, when landmarks are disabled, or when
   * `config.videoOptimized` is off; otherwise the boxes stored by the previous call are reused.
   *
   * @param image - input tensor; `image.shape[1]` and `image.shape[2]` are used to normalize
   *   the rotation center (assumes a [batch, height, width, channels] layout — TODO confirm)
   * @param config - configuration; reads `videoOptimized` and `hand.skipFrames`, `hand.landmarks`,
   *   `hand.maxDetected`, `hand.skipInitial`, `hand.rotation`, `hand.minConfidence`
   * @returns one entry per hand: `{ landmarks, confidence, box }` when landmarks are enabled,
   *   `{ confidence, box }` otherwise, where `box` is `{ topLeft, bottomRight }`
   */
  async estimateHands(image, config) {
    let useFreshBox = false;

    // run new detector every skipFrames unless we only want box to start with
    let boxes;
    if ((this.skipped === 0) || (this.skipped > config.hand.skipFrames) || !config.hand.landmarks || !config.videoOptimized) {
      boxes = await this.handDetector.estimateHandBounds(image, config);
      this.skipped = 0;
    }
    if (config.videoOptimized) this.skipped++;

    // if detector result count doesn't match current working set, use it to reset current working set
    if (boxes && (boxes.length > 0) && ((boxes.length !== this.detectedHands) && (this.detectedHands !== config.hand.maxDetected) || !config.hand.landmarks)) {
      this.detectedHands = 0;
      this.storedBoxes = [...boxes];
      // for (const possible of boxes) this.storedBoxes.push(possible);
      if (this.storedBoxes.length > 0) useFreshBox = true;
    }
    const hands: Array<{}> = [];
    // resetting the counter makes the detector run again on the next call
    if (config.hand.skipInitial && this.detectedHands === 0) this.skipped = 0;

    // go through working set of boxes
    for (let i = 0; i < this.storedBoxes.length; i++) {
      const currentBox = this.storedBoxes[i];
      if (!currentBox) continue;
      if (config.hand.landmarks) {
        const angle = config.hand.rotation ? util.computeRotation(currentBox.palmLandmarks[PALM_LANDMARKS_INDEX_OF_PALM_BASE], currentBox.palmLandmarks[PALM_LANDMARKS_INDEX_OF_MIDDLE_FINGER_BASE]) : 0;
        const palmCenter = box.getBoxCenter(currentBox);
        const palmCenterNormalized = [palmCenter[0] / image.shape[2], palmCenter[1] / image.shape[1]];
        const rotationMatrix = util.buildRotationMatrix(-angle, palmCenter);
        // boxes coming straight from the detector are rebuilt around their palm landmarks;
        // boxes carried over from the previous call are used as they are
        const newBox = useFreshBox ? this.getBoxForPalmLandmarks(currentBox.palmLandmarks, rotationMatrix) : currentBox;
        const { confidence, rawCoords } = await this.detectLandmarks(image, config, angle, palmCenterNormalized, newBox);
        if (rawCoords !== null) {
          const coords = this.transformRawCoords(rawCoords, newBox, angle, rotationMatrix);
          const nextBoundingBox = this.getBoxForHandLandmarks(coords);
          this.storedBoxes[i] = nextBoundingBox;
          const result = {
            landmarks: coords,
            confidence,
            box: { topLeft: nextBoundingBox.startPoint, bottomRight: nextBoundingBox.endPoint },
          };
          hands.push(result);
        } else {
          // confidence below threshold: drop the box from the working set
          this.storedBoxes[i] = null;
        }
      } else {
        // const enlarged = box.enlargeBox(box.squarifyBox(box.shiftBox(currentBox, HAND_BOX_SHIFT_VECTOR)), HAND_BOX_ENLARGE_FACTOR);
        const enlarged = box.enlargeBox(box.squarifyBox(currentBox), HAND_BOX_ENLARGE_FACTOR);
        const result = {
          confidence: currentBox.confidence,
          box: { topLeft: enlarged.startPoint, bottomRight: enlarged.endPoint },
        };
        hands.push(result);
      }
    }
    this.storedBoxes = this.storedBoxes.filter((a) => a !== null);
    this.detectedHands = hands.length;
    return hands;
  }

  /**
   * Computes the axis-aligned bounding box of a list of points.
   *
   * @param landmarks - list of points; only index 0 (x) and index 1 (y) of each point are read
   * @returns `{ startPoint: [minX, minY], endPoint: [maxX, maxY] }`
   */
  // eslint-disable-next-line class-methods-use-this
  calculateLandmarksBoundingBox(landmarks) {
    const xs = landmarks.map((d) => d[0]);
    const ys = landmarks.map((d) => d[1]);
    const startPoint = [Math.min(...xs), Math.min(...ys)];
    const endPoint = [Math.max(...xs), Math.max(...ys)];
    return { startPoint, endPoint };
  }
}