2021-09-21 22:48:16 +02:00
/**
 * HandTrack model implementation
 *
 * Based on:
 * - Hand Detection & Skeleton: [**MediaPipe HandPose**](https://drive.google.com/file/d/1sv4sSb9BSNVZhLzxXJ0jBv9DqD-4jnAz/view)
 * - Hand Tracking: [**HandTracking**](https://github.com/victordibia/handtracking)
 */
2021-10-22 22:09:52 +02:00
import { log , join , now } from '../util/util' ;
2021-10-09 00:39:04 +02:00
import * as box from '../util/box' ;
2021-09-21 22:48:16 +02:00
import * as tf from '../../dist/tfjs.esm.js' ;
2021-12-15 15:26:32 +01:00
import type { HandResult , HandType , Box , Point } from '../result' ;
2021-09-21 22:48:16 +02:00
import type { GraphModel , Tensor } from '../tfjs/types' ;
import type { Config } from '../config' ;
2021-09-27 19:58:13 +02:00
import { env } from '../util/env' ;
2021-09-28 18:01:48 +02:00
import * as fingerPose from './fingerpose' ;
2021-09-23 20:09:41 +02:00
import { fakeOps } from '../tfjs/backend' ;
2021-11-17 02:16:49 +01:00
import { constants } from '../tfjs/constants' ;
2021-09-21 22:48:16 +02:00
// Loaded graph models: slot 0 is assigned by loadDetect (hand detector), slot 1 by loadSkeleton (hand skeleton).
const models : [ GraphModel | null , GraphModel | null ] = [ null , null ] ;
// Output node names requested from the detector in detectHands; results are destructured as [rawScores, rawBoxes].
const modelOutputNodes = [ 'StatefulPartitionedCall/Postprocessor/Slice' , 'StatefulPartitionedCall/Postprocessor/ExpandDims_1' ] ;
2021-09-23 20:09:41 +02:00
2021-09-22 21:16:14 +02:00
// Input sizes read from each model signature at load time: inputSize[0] for the detector, inputSize[1] for the skeleton model.
// Each entry holds [dim[1], dim[2]] of the first model input (presumably [height, width] — confirm against the model signature).
const inputSize = [ [ 0 , 0 ] , [ 0 , 0 ] ] ;
2021-09-21 22:48:16 +02:00
2021-09-27 14:53:41 +02:00
// Class labels used by detectHands to turn a class number into a label.
const classes = [ 'hand' , 'fist' , 'pinch' , 'point' , 'face' , 'tip' , 'pinchtip' ] ;
2021-10-11 04:29:20 +02:00
// Position of the 'face' entry in classes; detectHands removes the score column at this index before picking the best class.
const faceIndex = 4 ;
2021-09-21 22:48:16 +02:00
2021-11-08 13:36:26 +01:00
// Scale factor applied in predict to keypoint-derived boxes before they are stored back into cache.boxes.
const boxExpandFact = 1.6 ;
2021-10-10 23:52:43 +02:00
// Upper bound for the detector input height computed in detectHands.
const maxDetectorResolution = 512 ;
2021-10-11 04:29:20 +02:00
// Scale factor applied in detectHands to each detected box to produce its boxRaw.
const detectorExpandFact = 1.4 ;
2021-10-09 00:39:04 +02:00
2021-10-23 15:38:52 +02:00
// Number of predict calls since the detector last ran; predict increments it on every call and resets it to 0 after running detectHands.
// Starts at the maximum safe integer so the frame-skip comparison in predict is false on the first call.
let skipped = Number . MAX_SAFE_INTEGER ;
2021-10-22 22:09:52 +02:00
// Value of now() recorded by predict right after the detector last ran.
let lastTime = 0 ;
2021-10-09 00:39:04 +02:00
// [input.shape[2], input.shape[1]] of the most recent predict input; used to scale boxes and keypoints to output coordinates.
let outputSize : [ number , number ] = [ 0 , 0 ] ;
2021-09-21 22:48:16 +02:00
// Single hand candidate as produced by detectHands and kept in cache.boxes between predict calls.
// predict may later replace box, boxRaw and boxCrop with values derived from detected keypoints.
type HandDetectResult = {
// running index assigned in the detectHands result loop
id : number ,
// best per-class score for the box, read from the max over the filtered class scores
score : number ,
2021-09-27 15:19:43 +02:00
// box multiplied by outputSize and truncated to integers
box : Box ,
// detected box scaled by detectorExpandFact
boxRaw : Box ,
2021-10-09 00:39:04 +02:00
// result of box.crop(boxRaw); passed to tf.image.cropAndResize in detectFingers
boxCrop : Box ,
2021-12-15 15:26:32 +01:00
// label looked up in classes from the detected class number
label : HandType ,
2021-09-21 22:48:16 +02:00
}
2021-09-22 21:16:14 +02:00
// Module-level results carried over between predict calls so detection can be skipped on some frames.
const cache : {
2021-10-09 00:39:04 +02:00
// hand boxes fed to detectFingers; filled by detectHands or rebuilt in predict from the latest keypoints
boxes : Array < HandDetectResult > ,
// hand results from the most recent predict run; returned as-is when a frame is skipped
hands : Array < HandResult > ;
2021-09-22 21:16:14 +02:00
} = {
2021-10-09 00:39:04 +02:00
boxes : [ ] ,
hands : [ ] ,
2021-09-22 21:16:14 +02:00
} ;
2021-09-21 22:48:16 +02:00
// Maps each annotation name to the indices of hand.keypoints that belong to it (used in detectFingers).
// Every finger list starts with index 0, which is also the only entry of palm.
const fingerMap = {
2021-11-24 22:17:03 +01:00
thumb : [ 0 , 1 , 2 , 3 , 4 ] ,
index : [ 0 , 5 , 6 , 7 , 8 ] ,
middle : [ 0 , 9 , 10 , 11 , 12 ] ,
ring : [ 0 , 13 , 14 , 15 , 16 ] ,
pinky : [ 0 , 17 , 18 , 19 , 20 ] ,
2021-09-21 22:48:16 +02:00
palm : [ 0 ] ,
} ;
2021-09-23 20:09:41 +02:00
/**
 * Loads the hand detector model into `models[0]`, or reuses the already loaded instance.
 *
 * When `env.initial` is set the cached instance is discarded and the model is loaded again.
 * After loading, the detector entry of `inputSize` is filled from the first input of the
 * model signature; it is set to 0 when the signature does not provide the dimensions.
 *
 * @param config library configuration; `config.modelBasePath` and `config.hand.detector.modelPath` locate the model
 * @returns the detector model stored in `models[0]`
 */
export async function loadDetect(config: Config): Promise<GraphModel> {
  // HandTrack Model: Original: <https://github.com/victordibia/handtracking> TFJS Port: <https://github.com/victordibia/handtrack.js/>
  if (env.initial) models[0] = null;
  if (!models[0]) {
    // handtrack model has some kernel ops defined in model but those are never referenced and non-existent in tfjs
    // ideally need to prune the model itself
    fakeOps(['tensorlistreserve', 'enter', 'tensorlistfromtensor', 'merge', 'loopcond', 'switch', 'exit', 'tensorliststack', 'nextiteration', 'tensorlistsetitem', 'tensorlistgetitem', 'reciprocal', 'shape', 'split', 'where'], config);
    models[0] = await tf.loadGraphModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel;
    // report the load result before touching the signature so a broken model is logged instead of throwing first
    if (!models[0] || !models[0]['modelUrl']) log('load model failed:', config.hand.detector?.modelPath);
    else if (config.debug) log('load model:', models[0]['modelUrl']);
    // Object.values always returns an array, so guard on the actual data instead of Array.isArray
    const signatureInputs = models[0]?.modelSignature?.['inputs'];
    const inputs = signatureInputs ? Object.values(signatureInputs) : [];
    const dims = inputs[0]?.tensorShape?.dim;
    inputSize[0][0] = dims?.[1] ? parseInt(dims[1].size) : 0;
    inputSize[0][1] = dims?.[2] ? parseInt(dims[2].size) : 0;
  } else if (config.debug) log('cached model:', models[0]['modelUrl']);
  return models[0];
}
/**
 * Loads the hand skeleton model into `models[1]`, or reuses the already loaded instance.
 *
 * When `env.initial` is set the cached instance is discarded and the model is loaded again.
 * After loading, the skeleton entry of `inputSize` is filled from the first input of the
 * model signature; it is set to 0 when the signature does not provide the dimensions.
 *
 * @param config library configuration; `config.modelBasePath` and `config.hand.skeleton.modelPath` locate the model
 * @returns the skeleton model stored in `models[1]`
 */
export async function loadSkeleton(config: Config): Promise<GraphModel> {
  if (env.initial) models[1] = null;
  if (!models[1]) {
    models[1] = await tf.loadGraphModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel;
    // report the load result before touching the signature so a broken model is logged instead of throwing first
    if (!models[1] || !models[1]['modelUrl']) log('load model failed:', config.hand.skeleton?.modelPath);
    else if (config.debug) log('load model:', models[1]['modelUrl']);
    // Object.values always returns an array, so guard on the actual data instead of Array.isArray
    const signatureInputs = models[1]?.modelSignature?.['inputs'];
    const inputs = signatureInputs ? Object.values(signatureInputs) : [];
    const dims = inputs[0]?.tensorShape?.dim;
    inputSize[1][0] = dims?.[1] ? parseInt(dims[1].size) : 0;
    inputSize[1][1] = dims?.[2] ? parseInt(dims[2].size) : 0;
  } else if (config.debug) log('cached model:', models[1]['modelUrl']);
  return models[1];
}
/**
 * Ensures both hand models are loaded, detector first and skeleton second.
 * A model that is already present in `models` is not loaded again.
 *
 * @param config library configuration forwarded to the individual loaders
 * @returns the shared `models` tuple: [detector, skeleton]
 */
export async function load(config: Config): Promise<[GraphModel | null, GraphModel | null]> {
  // loader position matches the slot in `models` that the loader fills
  const loaders = [loadDetect, loadSkeleton];
  for (let slot = 0; slot < loaders.length; slot++) {
    if (!models[slot]) await loaders[slot](config);
  }
  return models;
}
/**
 * Runs the hand detector (`models[0]`) on the input tensor and returns the detected hand boxes.
 *
 * The input is resized so that its height is a multiple of 8 capped at `maxDetectorResolution`,
 * with the width derived from the input aspect ratio. The face class is removed from the class
 * scores before non-max suppression, so faces are never reported as hands.
 * All intermediate tensors are released even when a model or tensor operation throws.
 *
 * @param input image tensor; `shape[1]` and `shape[2]` are used as height and width
 * @param config library configuration; `config.hand.maxDetected`, `iouThreshold` and `minConfidence` drive non-max suppression
 * @returns detected hands sorted by descending score, at most `config.hand.maxDetected || 1` entries
 */
async function detectHands(input: Tensor, config: Config): Promise<HandDetectResult[]> {
  const hands: HandDetectResult[] = [];
  if (!input || !models[0]) return hands;
  const t: Record<string, Tensor> = {};
  try {
    const ratio = (input.shape[2] || 1) / (input.shape[1] || 1);
    const height = Math.min(Math.round((input.shape[1] || 0) / 8) * 8, maxDetectorResolution); // use dynamic input size but cap at maxDetectorResolution
    const width = Math.round(height * ratio / 8) * 8;
    t.resize = tf.image.resizeBilinear(input, [height, width]); // todo: resize with padding
    t.cast = tf.cast(t.resize, 'int32');
    [t.rawScores, t.rawBoxes] = await models[0].executeAsync(t.cast, modelOutputNodes) as Tensor[];
    t.boxes = tf.squeeze(t.rawBoxes, [0, 2]);
    t.scores = tf.squeeze(t.rawScores, [0]);
    const classScores: Array<Tensor> = tf.unstack(t.scores, 1); // unstack scores based on classes
    try {
      t.filtered = tf.stack(classScores.filter((_, classIndex) => classIndex !== faceIndex), 1); // restack without faces
    } finally {
      tf.dispose(classScores); // releases every per-class tensor, including the face scores
    }
    t.max = tf.max(t.filtered, 1); // max overall score
    t.argmax = tf.argMax(t.filtered, 1); // class index of max overall score
    let id = 0;
    t.nms = await tf.image.nonMaxSuppressionAsync(t.boxes, t.max, config.hand.maxDetected, config.hand.iouThreshold, config.hand.minConfidence);
    const nms = await t.nms.data();
    const scores = await t.max.data();
    const classNum = await t.argmax.data();
    for (const nmsIndex of Array.from(nms)) { // generates results for each class
      const boxSlice = tf.slice(t.boxes, nmsIndex, 1);
      const boxYX = await boxSlice.data().finally(() => tf.dispose(boxSlice)); // release the slice even if the read fails
      const boxData: Box = [boxYX[1], boxYX[0], boxYX[3] - boxYX[1], boxYX[2] - boxYX[0]]; // yx box reshaped to standard box
      const boxRaw: Box = box.scale(boxData, detectorExpandFact);
      const boxCrop: Box = box.crop(boxRaw); // crop box is based on raw box
      const boxFull: Box = [Math.trunc(boxData[0] * outputSize[0]), Math.trunc(boxData[1] * outputSize[1]), Math.trunc(boxData[2] * outputSize[0]), Math.trunc(boxData[3] * outputSize[1])];
      const score = scores[nmsIndex];
      // argmax was computed after the face column was removed, so indices at or above faceIndex
      // are shifted down by one relative to `classes` and must be shifted back for the label lookup
      const filteredClass = classNum[nmsIndex];
      const label = classes[filteredClass >= faceIndex ? filteredClass + 1 : filteredClass] as HandType;
      const hand: HandDetectResult = { id: id++, score, box: boxFull, boxRaw, boxCrop, label };
      hands.push(hand);
    }
  } finally {
    Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
  }
  hands.sort((a, b) => b.score - a.score);
  if (hands.length > (config.hand.maxDetected || 1)) hands.length = (config.hand.maxDetected || 1);
  return hands;
}
/**
 * Runs the skeleton model (`models[1]`) on a single detected hand and builds the hand result.
 *
 * The result always carries the id, scores, boxes and label of the detected hand. Keypoints,
 * landmarks and annotations are filled only when landmarks are enabled, the detection score is
 * above `config.hand.minConfidence` and the skeleton score reaches that same threshold.
 * All intermediate tensors are released even when a model or tensor operation throws.
 *
 * @param input image tensor the hand was detected in
 * @param h detected hand; `h.boxCrop` selects the region passed to the skeleton model
 * @param config library configuration; uses `config.hand.landmarks` and `config.hand.minConfidence`
 * @returns hand result for this detection
 */
async function detectFingers(input: Tensor, h: HandDetectResult, config: Config): Promise<HandResult> {
  const hand: HandResult = { // initial values inherited from hand detect
    id: h.id,
    score: Math.round(100 * h.score) / 100,
    boxScore: Math.round(100 * h.score) / 100,
    fingerScore: 0,
    box: h.box,
    boxRaw: h.boxRaw,
    label: h.label,
    keypoints: [],
    landmarks: {} as HandResult['landmarks'],
    annotations: {} as HandResult['annotations'],
  };
  if (input && models[1] && config.hand.landmarks && h.score > (config.hand.minConfidence || 0)) {
    const t: Record<string, Tensor> = {};
    try {
      t.crop = tf.image.cropAndResize(input, [h.boxCrop], [0], [inputSize[1][0], inputSize[1][1]], 'bilinear');
      t.div = tf.div(t.crop, constants.tf255);
      [t.score, t.keypoints] = models[1].execute(t.div, ['Identity_1', 'Identity']) as Tensor[];
      const rawScore = (await t.score.data())[0];
      const score = (100 - Math.trunc(100 / (1 + Math.exp(rawScore)))) / 100; // reverse sigmoid value
      if (score >= (config.hand.minConfidence || 0)) {
        hand.fingerScore = score;
        t.reshaped = tf.reshape(t.keypoints, [-1, 3]);
        const coordsData: Point[] = await t.reshaped.array() as Point[];
        // keypoints divided by the skeleton model input size
        const coordsRaw: Point[] = coordsData.map((kpt) => [kpt[0] / inputSize[1][1], kpt[1] / inputSize[1][0], (kpt[2] || 0)]);
        // scaled by the width and height of the hand raw box
        const coordsNorm: Point[] = coordsRaw.map((kpt) => [kpt[0] * h.boxRaw[2], kpt[1] * h.boxRaw[3], (kpt[2] || 0)]);
        // offset by the raw box origin and scaled to output size
        hand.keypoints = (coordsNorm).map((kpt) => [
          outputSize[0] * (kpt[0] + h.boxRaw[0]),
          outputSize[1] * (kpt[1] + h.boxRaw[1]),
          (kpt[2] || 0),
        ]);
        hand.landmarks = fingerPose.analyze(hand.keypoints) as HandResult['landmarks']; // calculate finger gestures
        for (const key of Object.keys(fingerMap)) { // map keypoints to per-finger annotations
          hand.annotations[key] = fingerMap[key].map((index) => (hand.landmarks && hand.keypoints[index] ? hand.keypoints[index] : null));
        }
      }
    } finally {
      Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
    }
  }
  return hand;
}
/**
 * Detects hands and their keypoints in the input tensor, reusing cached results where allowed.
 *
 * Flow:
 * 1. If skipping is allowed and both the time and frame limits are not yet reached, the cached hands are returned unchanged.
 * 2. If skipping is allowed and the cache already holds `config.hand.maxDetected` hands, or holds some hands and
 *    the extended (3x) time and frame limits are not yet reached, only the skeleton model is re-run on the cached boxes.
 * 3. Otherwise the detector is re-run, followed by the skeleton model on the fresh boxes.
 * Afterwards the box cache is rebuilt from the detected keypoints (when `config.cacheSensitivity > 0`)
 * and each returned hand gets its box recalculated from its keypoints.
 *
 * Errors raised by the detector or skeleton model reject the returned promise.
 *
 * @param input image tensor; `shape[1]` and `shape[2]` are used as height and width
 * @param config library configuration
 * @returns detected hands, or an empty array when the models are not usable
 */
export async function predict(input: Tensor, config: Config): Promise<HandResult[]> {
  if (!models[0] || !models[1] || !models[0]?.inputs[0].shape || !models[1]?.inputs[0].shape) return []; // something is wrong with the model
  outputSize = [input.shape[2] || 0, input.shape[1] || 0];
  skipped++; // increment skip frames
  const skipTime = (config.hand.skipTime || 0) > (now() - lastTime);
  const skipFrame = skipped < (config.hand.skipFrames || 0);
  if (config.skipAllowed && skipTime && skipFrame) {
    return cache.hands; // return cached results without running anything
  }
  const skipTimeExtended = 3 * (config.hand.skipTime || 0) > (now() - lastTime);
  const skipFrameExtended = skipped < 3 * (config.hand.skipFrames || 0);
  const haveAllHands = cache.hands.length === config.hand.maxDetected; // we have all detected hands so we're definitely skipping
  const haveSomeHands = skipTimeExtended && skipFrameExtended && cache.hands.length > 0; // we have some cached results: maybe not enough but anyhow continue for bit longer
  if (config.skipAllowed && (haveAllHands || haveSomeHands)) {
    cache.hands = await Promise.all(cache.boxes.map((handBox) => detectFingers(input, handBox, config)));
  } else { // finally rerun detector
    cache.boxes = await detectHands(input, config);
    lastTime = now();
    cache.hands = await Promise.all(cache.boxes.map((handBox) => detectFingers(input, handBox, config)));
    skipped = 0;
  }

  const oldCache = [...cache.boxes];
  cache.boxes.length = 0; // reset cache
  if (config.cacheSensitivity > 0) {
    for (let i = 0; i < cache.hands.length; i++) {
      const boxKpt = box.square(cache.hands[i].keypoints, outputSize);
      // keep a hand in the cache only if its keypoint box covers more than 5% of the input in both dimensions and its finger score is above threshold
      if (boxKpt.box[2] / (input.shape[2] || 1) > 0.05 && boxKpt.box[3] / (input.shape[1] || 1) > 0.05 && cache.hands[i].fingerScore && cache.hands[i].fingerScore > (config.hand.minConfidence || 0)) {
        const boxScale = box.scale(boxKpt.box, boxExpandFact);
        const boxScaleRaw = box.scale(boxKpt.boxRaw, boxExpandFact);
        const boxCrop = box.crop(boxScaleRaw);
        cache.boxes.push({ ...oldCache[i], box: boxScale, boxRaw: boxScaleRaw, boxCrop });
      }
    }
  }
  for (let i = 0; i < cache.hands.length; i++) { // replace detected boxes with calculated boxes in final output
    const bbox = box.calc(cache.hands[i].keypoints, outputSize);
    cache.hands[i].box = bbox.box;
    cache.hands[i].boxRaw = bbox.boxRaw;
  }
  return cache.hands;
}