// human/src/hand/handdetector.js
//
// 99 lines · 4.1 KiB · JavaScript (file-viewer header: Raw | Normal View | History)
//
// 2020-11-04 07:11:24 +01:00
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 2020-11-17 18:38:48 +01:00
import { tf } from '../../dist/tfjs.esm.js';
// 2020-11-10 02:13:38 +01:00
import * as box from './box';
// 2020-10-12 01:22:43 +02:00
/**
 * Palm detector: runs the palm-detection model on a square, normalized image,
 * decodes per-anchor predictions into boxes and 7 palm landmark points, and
 * filters them with non-max suppression.
 */
class HandDetector {
  /**
   * @param {object} model - model exposing `predict(input)`; presumably a tf.GraphModel — confirm against the loader.
   * @param {number} inputSize - side length (pixels) of the square model input.
   * @param {Array<{x_center: number, y_center: number}>} anchorsAnnotated - anchor centers; one per model output row
   *   (required so the anchors tensor broadcasts against the decoded rows in `normalizeBoxes`).
   */
  constructor(model, inputSize, anchorsAnnotated) {
    this.model = model;
    this.anchors = anchorsAnnotated.map((anchor) => [anchor.x_center, anchor.y_center]);
    this.anchorsTensor = tf.tensor2d(this.anchors);
    this.inputSizeTensor = tf.tensor1d([inputSize, inputSize]);
    this.doubleInputSizeTensor = tf.tensor1d([inputSize * 2, inputSize * 2]);
  }

  /**
   * Release the tensors allocated by the constructor. The model itself is not
   * disposed here; its owner is responsible for that.
   */
  dispose() {
    this.anchorsTensor.dispose();
    this.inputSizeTensor.dispose();
    this.doubleInputSizeTensor.dispose();
  }

  /**
   * Decode raw box regressions into corner coordinates in input-pixel space.
   *
   * Columns 0-1 of each row are an (x, y) center offset in pixels relative to
   * that row's anchor; columns 2-3 are the box size in pixels.
   *
   * @param {tf.Tensor2D} boxes - raw box regressions, one row per anchor.
   * @returns {tf.Tensor2D} rows of [startX, startY, endX, endY].
   */
  normalizeBoxes(boxes) {
    return tf.tidy(() => {
      const boxOffsets = tf.slice(boxes, [0, 0], [-1, 2]);
      const boxSizes = tf.slice(boxes, [0, 2], [-1, 2]);
      // Work in anchor-relative [0, 1] units, then scale back to pixels.
      const boxCenterPoints = tf.add(tf.div(boxOffsets, this.inputSizeTensor), this.anchorsTensor);
      const halfBoxSizes = tf.div(boxSizes, this.doubleInputSizeTensor);
      const startPoints = tf.mul(tf.sub(boxCenterPoints, halfBoxSizes), this.inputSizeTensor);
      const endPoints = tf.mul(tf.add(boxCenterPoints, halfBoxSizes), this.inputSizeTensor);
      return tf.concat2d([startPoints, endPoints], 1);
    });
  }

  /**
   * Convert raw palm-landmark offsets for one anchor into input-pixel coordinates.
   *
   * @param {tf.Tensor} rawPalmLandmarks - 14 values per row, read as 7 (x, y) offsets.
   * @param {number} index - anchor index whose center the offsets are relative to.
   * @returns {tf.Tensor3D} shape [rows, 7, 2].
   */
  normalizeLandmarks(rawPalmLandmarks, index) {
    return tf.tidy(() => {
      const landmarks = tf.add(tf.div(rawPalmLandmarks.reshape([-1, 7, 2]), this.inputSizeTensor), this.anchors[index]);
      return tf.mul(landmarks, this.inputSizeTensor);
    });
  }

  /**
   * Run the model and return the palm detections that survive NMS and `config.minConfidence`.
   *
   * Each prediction row is read as: column 0 = score logit, columns 1-4 = box,
   * columns 5-18 = 7 palm landmark (x, y) offsets.
   *
   * @param {tf.Tensor} input - preprocessed model input.
   * @param {{maxHands: number, iouThreshold: number, scoreThreshold: number, minConfidence: number}} config
   * @returns {Promise<Array<{box: tf.Tensor2D, palmLandmarks: tf.Tensor2D, confidence: number}>>}
   *   The caller owns (and must dispose) each returned `box` and `palmLandmarks` tensor.
   */
  async getBoxes(input, config) {
    const batched = this.model.predict(input);
    const predictions = batched.squeeze();
    batched.dispose();
    // Declared outside `try` so `finally` can release whatever was allocated before
    // a failure (e.g. a rejected nonMaxSuppressionAsync); previously those leaked.
    let scores = null;
    let boxes = null;
    try {
      scores = tf.tidy(() => tf.sigmoid(tf.slice(predictions, [0, 0], [-1, 1])).squeeze());
      const scoresVal = scores.dataSync();
      // tidy disposes the intermediate raw-box slice and keeps only the decoded boxes.
      boxes = tf.tidy(() => this.normalizeBoxes(tf.slice(predictions, [0, 1], [-1, 4])));
      const filteredT = await tf.image.nonMaxSuppressionAsync(boxes, scores, config.maxHands, config.iouThreshold, config.scoreThreshold);
      const filtered = filteredT.arraySync();
      filteredT.dispose();

      const hands = [];
      for (const boxIndex of filtered) {
        const confidence = scoresVal[boxIndex];
        if (confidence >= config.minConfidence) {
          const matchingBox = tf.slice(boxes, [boxIndex, 0], [1, -1]);
          const palmLandmarks = tf.tidy(() => {
            const rawPalmLandmarks = tf.slice(predictions, [boxIndex, 5], [1, 14]);
            return this.normalizeLandmarks(rawPalmLandmarks, boxIndex).reshape([-1, 2]);
          });
          hands.push({ box: matchingBox, palmLandmarks, confidence });
        }
      }
      return hands;
    } finally {
      predictions.dispose();
      if (scores) scores.dispose();
      if (boxes) boxes.dispose();
    }
  }

  /**
   * Detect palms in an image and return their boxes scaled to the image's size.
   *
   * @param {tf.Tensor4D} input - image tensor; shape[1] is read as height and shape[2] as width
   *   (presumably [1, height, width, channels] with 0..255 pixel values — confirm against callers).
   * @param {object} config - needs `inputSize` plus the fields used by `getBoxes`.
   * @returns {Promise<Array<object>|null>} results of `box.scaleBoxCoordinates`, one per hand,
   *   or `null` when no hand passes the filters.
   */
  async estimateHandBounds(input, config) {
    const inputHeight = input.shape[1];
    const inputWidth = input.shape[2];
    // Resize to the model's square input and map pixel values v to v / 127.5 - 1.
    const image = tf.tidy(() => input.resizeBilinear([config.inputSize, config.inputSize]).div(127.5).sub(1));
    let predictions;
    try {
      predictions = await this.getBoxes(image, config);
    } finally {
      // Release the resized image even if detection throws.
      image.dispose();
    }
    if (!predictions || predictions.length === 0) return null;

    const hands = [];
    for (const prediction of predictions) {
      const boxCoords = prediction.box.dataSync();
      const startPoint = boxCoords.slice(0, 2);
      const endPoint = boxCoords.slice(2, 4);
      const palmLandmarks = prediction.palmLandmarks.arraySync();
      prediction.box.dispose();
      prediction.palmLandmarks.dispose();
      // Scale from model-input pixels back to the original image's pixels.
      hands.push(box.scaleBoxCoordinates({ startPoint, endPoint, palmLandmarks, confidence: prediction.confidence }, [inputWidth / config.inputSize, inputHeight / config.inputSize]));
    }
    return hands;
  }
}
// Publishes HandDetector through a CommonJS-style `exports` assignment even though
// this file uses ES `import` syntax.
// NOTE(review): this presumably relies on the bundler (e.g. esbuild) providing
// CommonJS/ESM interop. Under native Node ESM, `exports` is not defined and this line
// would throw a ReferenceError. Confirm how consumers load this module before
// switching to `export { HandDetector };`.
exports.HandDetector = HandDetector;