human/src/posenet/decoders.ts

55 lines
1.9 KiB
TypeScript
Raw Normal View History

2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-11-10 02:13:38 +01:00
import * as kpt from './keypoints';
2020-10-12 01:22:43 +02:00
2021-02-08 17:39:09 +01:00
/**
 * Samples the heatmap score found at each keypoint's `[y, x]` cell.
 *
 * @param heatmapScores buffer read as `get(y, x, keypoint)`
 * @param heatMapCoords buffer holding one `[y, x]` row per keypoint; its first
 *   shape dimension decides how many keypoints are sampled
 * @returns a Float32Array with one score per keypoint
 */
export function getPointsConfidence(heatmapScores, heatMapCoords) {
  const count = heatMapCoords.shape[0];
  return Float32Array.from({ length: count }, (_, k) => {
    const row = heatMapCoords.get(k, 0);
    const col = heatMapCoords.get(k, 1);
    return heatmapScores.get(row, col, k);
  });
}
2020-12-17 00:36:24 +01:00
2020-10-12 01:22:43 +02:00
/**
 * Reads the `(y, x)` offset pair stored for one keypoint at a given cell.
 *
 * The y component is read from channel `keypoint`. The x component is read
 * from channel `keypoint + kpt.NUM_KEYPOINTS`.
 */
function getOffsetPoint(y, x, keypoint, offsetsBuffer) {
  const xChannel = keypoint + kpt.NUM_KEYPOINTS;
  const offsetY = offsetsBuffer.get(y, x, keypoint);
  const offsetX = offsetsBuffer.get(y, x, xChannel);
  return { y: offsetY, x: offsetX };
}
2020-12-17 00:36:24 +01:00
2021-02-08 17:39:09 +01:00
/**
 * Gathers the offset vector for every keypoint at its heatmap cell.
 *
 * @param heatMapCoordsBuffer buffer holding one `[y, x]` row per keypoint
 * @param offsetsBuffer buffer sampled through `getOffsetPoint`
 * @returns a `[kpt.NUM_KEYPOINTS, 2]` tensor whose rows are `[offsetY, offsetX]`
 */
export function getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer) {
  const flat: number[] = [];
  for (let k = 0; k < kpt.NUM_KEYPOINTS; k += 1) {
    const cellY = heatMapCoordsBuffer.get(k, 0).valueOf();
    const cellX = heatMapCoordsBuffer.get(k, 1).valueOf();
    const offset = getOffsetPoint(cellY, cellX, k, offsetsBuffer);
    flat.push(offset.y, offset.x);
  }
  return tf.tensor2d(flat, [kpt.NUM_KEYPOINTS, 2]);
}
2020-12-17 00:36:24 +01:00
2021-02-08 17:39:09 +01:00
/**
 * Maps per-keypoint heatmap cells to refined points.
 *
 * Each `[y, x]` cell from `heatMapCoordsBuffer` is multiplied by
 * `outputStride`, cast to float32, and then shifted by the offset vectors from
 * `getOffsetVectors`. Intermediate tensors are released by `tf.tidy`.
 *
 * @param heatMapCoordsBuffer buffer with one `[y, x]` row per keypoint
 *   (presumably integer cell indices, since the stride is multiplied as int32 — confirm against caller)
 * @param outputStride scale factor applied to cell coordinates (presumably input pixels per heatmap cell)
 * @param offsetsBuffer offsets buffer forwarded to `getOffsetVectors`
 * @returns float32 tensor of refined `[y, x]` points, one row per keypoint
 */
export function getOffsetPoints(heatMapCoordsBuffer, outputStride, offsetsBuffer) {
  return tf.tidy(() => {
    const scaled = heatMapCoordsBuffer.toTensor().mul(tf.scalar(outputStride, 'int32'));
    // chained `Tensor.toFloat()` is deprecated in tfjs; `tf.cast` is the supported equivalent
    return tf.cast(scaled, 'float32').add(getOffsetVectors(heatMapCoordsBuffer, offsetsBuffer));
  });
}
/**
 * Element-wise remainder `a - floor(a / b) * b` of a tensor by a scalar.
 *
 * The quotient is computed with `tf.floorDiv` explicitly. Plain `div` only
 * floors when both operands are int32, so for a float `a` the result would
 * collapse to ~0 instead of the true remainder.
 *
 * @param a tensor to reduce (argmax2d passes flat argMax indices)
 * @param b divisor, wrapped in an int32 scalar
 * @returns tensor of remainders with the same shape as `a`
 */
function mod(a, b) {
  return tf.tidy(() => {
    const divisor = tf.scalar(b, 'int32');
    const floored = tf.floorDiv(a, divisor);
    return a.sub(floored.mul(divisor));
  });
}
2020-12-17 00:36:24 +01:00
2021-02-08 17:39:09 +01:00
/**
 * For every channel of a `[height, width, depth]` tensor, locates the `(y, x)`
 * cell holding the channel's maximum value.
 *
 * @param inputs rank-3 tensor; its shape is destructured as `[height, width, depth]`
 * @returns tensor of shape `[depth, 2]` whose rows are `[y, x]`
 */
export function argmax2d(inputs) {
  const [height, width, depth] = inputs.shape;
  return tf.tidy(() => {
    // flatten the spatial dims so one argMax over axis 0 gives a flat index per channel
    const reshaped = inputs.reshape([height * width, depth]);
    const coords = reshaped.argMax(0);
    // row = floor(index / width): floor explicitly instead of relying on int32 `div` flooring
    const yCoords = tf.floorDiv(coords, tf.scalar(width, 'int32')).expandDims(1);
    // column = index mod width
    const xCoords = mod(coords, width).expandDims(1);
    return tf.concat([yCoords, xCoords], 1);
  });
}