human/src/face/facemesh.js

58 lines
2.3 KiB
JavaScript
Raw Normal View History

2020-10-12 01:22:43 +02:00
const tf = require('@tensorflow/tfjs');
2020-10-13 04:01:35 +02:00
const blazeface = require('./blazeface');
2020-10-12 01:22:43 +02:00
const keypoints = require('./keypoints');
2020-11-04 07:11:24 +01:00
const pipe = require('./facepipeline');
2020-10-12 01:22:43 +02:00
const uv_coords = require('./uvcoords');
2020-10-12 03:21:41 +02:00
const triangulation = require('./triangulation').default;
2020-10-12 01:22:43 +02:00
/**
 * Face mesh estimator that delegates detection to a `pipe.Pipeline` instance
 * and converts each pipeline prediction into a plain result object.
 */
class MediaPipeFaceMesh {
  /**
   * @param {*} blazeFace - face detector, forwarded unchanged to the pipeline
   * @param {*} blazeMeshModel - mesh model, forwarded unchanged to the pipeline
   * @param {*} irisModel - iris model, forwarded unchanged to the pipeline
   * @param {object} [config] - configuration; remembered on the instance when provided
   */
  constructor(blazeFace, blazeMeshModel, irisModel, config) {
    this.pipeline = new pipe.Pipeline(blazeFace, blazeMeshModel, irisModel, config);
    if (config) this.config = config;
  }

  /**
   * Runs the pipeline on `input` and builds one result per prediction.
   *
   * Each result contains:
   *  - `confidence`: `prediction.confidence`, or 0 when falsy
   *  - `box`: `[start[0], start[1], end[0] - start[0], end[1] - start[1]]` built from
   *    `prediction.box.startPoint` / `endPoint`, or 0 when the prediction has no box
   *  - `mesh`: `prediction.coords.arraySync()`, or null when there are no coords
   *  - `annotations`: for every key of `keypoints.MESH_ANNOTATIONS`, the mesh entries
   *    at the listed indices; keys containing 'Iris' are only included when
   *    `config.iris.enabled` is truthy
   *  - `image`: a `tf.clone` of `prediction.image`, or null when there is no image
   *
   * The prediction's `coords` and `image` are disposed once the result is built,
   * and also when building the result throws.
   *
   * @param {*} input - forwarded unchanged to `pipeline.predict`
   * @param {object} [config] - when provided, replaces the remembered configuration
   * @returns {Promise<Array<object>>} one result object per usable prediction
   */
  async estimateFaces(input, config) {
    if (config) this.config = config;
    const predictions = await this.pipeline.predict(input, config);
    // No configuration may ever have been supplied; treat that as "iris disabled"
    // instead of failing on an undefined property access.
    const irisEnabled = Boolean(this.config && this.config.iris && this.config.iris.enabled);
    const results = [];
    for (const prediction of (predictions || [])) {
      // guard against disposed tensors on long running operations such as pause in middle of processing
      if (prediction.isDisposedInternal) continue;
      try {
        const mesh = prediction.coords ? prediction.coords.arraySync() : null;
        const annotations = {};
        if (mesh && mesh.length > 0) {
          for (const key of Object.keys(keypoints.MESH_ANNOTATIONS)) {
            if (irisEnabled || !key.includes('Iris')) {
              annotations[key] = keypoints.MESH_ANNOTATIONS[key].map((index) => mesh[index]);
            }
          }
        }
        const { box } = prediction;
        results.push({
          confidence: prediction.confidence || 0,
          box: box
            ? [
              box.startPoint[0],
              box.startPoint[1],
              box.endPoint[0] - box.startPoint[0],
              box.endPoint[1] - box.startPoint[1],
            ]
            : 0,
          mesh,
          annotations,
          // clone first: the original image is released in the finally block below
          image: prediction.image ? tf.clone(prediction.image) : null,
        });
      } finally {
        // release pipeline-owned tensors even if converting this prediction failed
        if (prediction.coords) prediction.coords.dispose();
        if (prediction.image) prediction.image.dispose();
      }
    }
    return results;
  }
}
2020-10-13 04:01:35 +02:00
/**
 * Loads the face detector (via `blazeface.load`) together with the mesh and
 * iris graph models, then wraps all three in a `MediaPipeFaceMesh`.
 *
 * All three loads are started before any of them is awaited. A graph model is
 * loaded with `fromTFHub: true` when its path contains 'tfhub.dev'.
 *
 * @param {object} config - read for `mesh.modelPath` and `iris.modelPath`; also
 *   forwarded to `blazeface.load` and to the `MediaPipeFaceMesh` constructor
 * @returns {Promise<MediaPipeFaceMesh>} the assembled face mesh instance
 */
async function load(config) {
  const fetchGraphModel = (modelPath) => tf.loadGraphModel(modelPath, {
    fromTFHub: modelPath.includes('tfhub.dev'),
  });

  // kick off every load first so they proceed in parallel
  const pending = [];
  pending.push(blazeface.load(config));
  pending.push(fetchGraphModel(config.mesh.modelPath));
  pending.push(fetchGraphModel(config.iris.modelPath));

  const [detector, meshModel, irisModel] = await Promise.all(pending);
  return new MediaPipeFaceMesh(detector, meshModel, irisModel, config);
}
exports.load = load;
2020-10-12 01:22:43 +02:00
exports.MediaPipeFaceMesh = MediaPipeFaceMesh;
2020-10-13 04:01:35 +02:00
exports.uv_coords = uv_coords;
exports.triangulation = triangulation;