2021-09-25 17:51:15 +02:00
/**
 * PoseNet body detection model implementation
 *
 * Based on: [**PoseNet**](https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5)
 */
2022-10-17 02:28:57 +02:00
import * as tf from 'dist/tfjs.esm.js' ;
2022-01-17 17:03:21 +01:00
import { log } from '../util/util' ;
2022-01-16 15:49:55 +01:00
import { loadModel } from '../tfjs/load' ;
2021-12-15 15:26:32 +01:00
import type { BodyResult , BodyLandmark , Box } from '../result' ;
2022-10-17 02:28:57 +02:00
import type { Tensor , GraphModel , Tensor4D } from '../tfjs/types' ;
2021-09-28 18:01:48 +02:00
import type { Config } from '../config' ;
import { env } from '../util/env' ;
import * as utils from './posenetutils' ;
/** Cached graph model instance; assigned by `load()` and read by `predict()`. */
let model: GraphModel;

/**
 * Output node names requested from the model.
 * The order matters: `predict()` forwards the results positionally to `decode()`.
 */
const poseNetOutputs = [
  'MobilenetV1/offset_2/BiasAdd', // offsets
  'MobilenetV1/heatmap_2/BiasAdd', // heatmapScores
  'MobilenetV1/displacement_fwd_2/BiasAdd', // displacementFwd
  'MobilenetV1/displacement_bwd_2/BiasAdd', // displacementBwd
];

/** Half-size, in heatmap cells, of the window scanned by the local-maximum test. */
const localMaximumRadius = 1;

/** Factor used to convert between image coordinates and heatmap cell indices. */
const outputStride = 16;

/** Squared non-maximum-suppression radius; compared against squared distances between keypoints. */
const squaredNmsRadius = 50 * 50;
2021-04-24 22:04:49 +02:00
2022-08-21 21:23:03 +02:00
/**
 * Follows one edge of the pose tree from an already decoded keypoint to its neighbour.
 *
 * The source position is shifted by the displacement stored for `edgeId`, then the
 * estimate is refined `offsetRefineStep` times by snapping to the nearest heatmap
 * cell and adding that cell's offset vector for `targetId`.
 *
 * @param edgeId index of the edge inside the displacement buffer
 * @param sourceKeypoint decoded keypoint the walk starts from (its `position` is read)
 * @param targetId part id of the keypoint being located
 * @param scores heatmap score buffer (`shape` and `get(y, x, id)` are used)
 * @param offsets offset buffer handed to `utils.getOffsetPoint`
 * @param displacements displacement buffer; y components are read from channel `edgeId`,
 *   x components from channel `shape[2] / 2 + edgeId`
 * @param offsetRefineStep number of refinement passes (defaults to 2)
 * @returns keypoint record `{ position, part, score }` for the target part
 */
function traverse(edgeId: number, sourceKeypoint, targetId, scores, offsets, displacements, offsetRefineStep = 2) {
  const [rows, cols] = scores.shape;

  // nearest heatmap cell for an image-space point, clamped to the grid
  const cellOf = (pt) => {
    const y = utils.clamp(Math.round(pt.y / outputStride), 0, rows - 1);
    const x = utils.clamp(Math.round(pt.x / outputStride), 0, cols - 1);
    return { y, x };
  };

  // nearest-neighbour lookup of the source->target displacement
  const sourceCell = cellOf(sourceKeypoint.position);
  const shift = {
    y: displacements.get(sourceCell.y, sourceCell.x, edgeId),
    x: displacements.get(sourceCell.y, sourceCell.x, (displacements.shape[2] / 2) + edgeId),
  };

  let position = utils.addVectors(sourceKeypoint.position, shift);
  let pass = 0;
  while (pass < offsetRefineStep) {
    const cell = cellOf(position);
    const { x: offsetX, y: offsetY } = utils.getOffsetPoint(cell.y, cell.x, targetId, offsets);
    position = utils.addVectors(
      { x: cell.x * outputStride, y: cell.y * outputStride },
      { x: offsetX, y: offsetY },
    );
    pass += 1;
  }

  // score is sampled at the cell nearest to the refined position
  const finalCell = cellOf(position);
  const score = scores.get(finalCell.y, finalCell.x, targetId);
  return { position, part: utils.partNames[targetId], score };
}
2021-05-05 02:46:33 +02:00
/**
 * Decodes a full pose starting from a single root part.
 *
 * The root keypoint is placed first. The pose chain is then walked twice: once from the
 * last edge to the first using the backward displacements (child -> parent), and once
 * from the first edge to the last using the forward displacements (parent -> child).
 * A keypoint is only filled in when its neighbour is already known and it is still empty.
 *
 * @param root root candidate `{ score, part: { heatmapY, heatmapX, id } }`
 * @param scores heatmap score buffer; `shape[2]` gives the number of parts
 * @param offsets offset buffer
 * @param displacementsFwd forward (parent -> child) displacement buffer
 * @param displacementsBwd backward (child -> parent) displacement buffer
 * @returns array with one slot per part id, holding `{ score, part, position }` records
 */
export function decodePose(root, scores, offsets, displacementsFwd, displacementsBwd) {
  // resolve the named pose chain into parallel lists of parent and child part ids
  const parentIds = [];
  const childIds = [];
  for (const [parentName, childName] of utils.poseChain) {
    parentIds.push(utils.partIds[parentName]);
    childIds.push(utils.partIds[childName]);
  }
  const edgeCount = childIds.length;

  const keypoints = new Array(scores.shape[2]); // e.g. scores.shape is [21,21,17]

  // start a new detection instance at the position of the root
  const rootId = root.part.id;
  keypoints[rootId] = {
    score: root.score,
    part: utils.partNames[rootId] as BodyLandmark,
    position: utils.getImageCoords(root.part, outputStride, offsets),
  };

  // fill `toId` from `fromId` when the source is known and the target is not
  const extend = (edge, fromId, toId, displacements) => {
    if (keypoints[fromId] && !keypoints[toId]) {
      keypoints[toId] = traverse(edge, keypoints[fromId], toId, scores, offsets, displacements);
    }
  };

  // upwards in the tree, following the backward displacements
  for (let edge = edgeCount - 1; edge >= 0; edge--) extend(edge, childIds[edge], parentIds[edge], displacementsBwd);
  // downwards in the tree, following the forward displacements
  for (let edge = 0; edge < edgeCount; edge++) extend(edge, parentIds[edge], childIds[edge], displacementsFwd);

  return keypoints;
}
2022-08-21 21:23:03 +02:00
/**
 * Tests whether `score` is the largest value for `keypointId` inside the square window of
 * radius `localMaximumRadius` centred on (`heatmapY`, `heatmapX`), clipped to the heatmap.
 *
 * @returns `false` as soon as any cell in the window holds a strictly greater score, else `true`
 */
function scoreIsMaximumInLocalWindow(keypointId, score: number, heatmapY: number, heatmapX: number, scores) {
  const [height, width]: [number, number] = scores.shape;
  const top = Math.max(heatmapY - localMaximumRadius, 0);
  const bottom = Math.min(heatmapY + localMaximumRadius + 1, height);
  const left = Math.max(heatmapX - localMaximumRadius, 0);
  const right = Math.min(heatmapX + localMaximumRadius + 1, width);
  for (let row = top; row < bottom; row++) {
    for (let col = left; col < right; col++) {
      if (scores.get(row, col, keypointId) > score) return false;
    }
  }
  return true;
}
2021-04-25 19:16:04 +02:00
/**
 * Collects root candidates from the heatmap into a max-heap ordered by score.
 *
 * A cell/part pair becomes a candidate when its score is not below `minConfidence`
 * and it is the maximum within its local window.
 *
 * @param minConfidence threshold below which a score is skipped
 * @param scores heatmap score buffer (`shape` is read as `[height, width, numKeypoints]`)
 * @returns `utils.MaxHeap` of `{ score, part: { heatmapY, heatmapX, id } }` entries
 */
export function buildPartWithScoreQueue(minConfidence, scores) {
  const [rows, cols, parts] = scores.shape;
  const heap = new utils.MaxHeap(rows * cols * parts, (entry) => entry.score);
  for (let y = 0; y < rows; y++) {
    for (let x = 0; x < cols; x++) {
      for (let id = 0; id < parts; id++) {
        const value = scores.get(y, x, id);
        // keep only parts at or above the threshold that are also the local maximum
        if (!(value < minConfidence) && scoreIsMaximumInLocalWindow(id, value, y, x, scores)) {
          heap.enqueue({ score: value, part: { heatmapY: y, heatmapX: x, id } });
        }
      }
    }
  }
  return heap;
}
2021-04-25 19:16:04 +02:00
/**
 * Checks whether a point lies within the NMS radius of the keypoint named `partName`
 * in any of the given poses.
 *
 * Keypoints are matched by their `part` name rather than by array position: `decode()`
 * stores poses whose keypoints were filtered by confidence, so an index in those arrays
 * does not reliably identify a body part.
 */
function isPartWithinRadius(poses, { x, y }, partName) {
  return poses.some(({ keypoints }) => {
    const match = keypoints.find((keypoint) => keypoint?.part === partName)?.position;
    if (!match) return false; // this pose has no keypoint for the part
    return utils.squaredDistance(y, x, match.y, match.x) <= squaredNmsRadius;
  });
}

/**
 * Part-based non-maximum suppression test.
 *
 * @param poses previously accepted poses (each with a `keypoints` array)
 * @param point `{ x, y }` position of the candidate keypoint
 * @param keypointId part id of the candidate; resolved to a name through `utils.partNames`
 * @returns `true` when some pose already has the same part within the NMS radius
 */
function withinRadius(poses, { x, y }, keypointId) {
  return isPartWithinRadius(poses, { x, y }, utils.partNames[keypointId]);
}

/**
 * Scores a pose as the sum of the scores of its keypoints that do not overlap the same
 * part of an existing pose, divided by the number of keypoints.
 *
 * @param existingPoses previously accepted poses
 * @param keypoints keypoints of the candidate pose (`{ part, position, score }` records)
 * @returns the instance score, or 0 for an empty keypoint list (avoids a 0/0 `NaN`)
 */
function getInstanceScore(existingPoses, keypoints) {
  if (keypoints.length === 0) return 0;
  const notOverlappedKeypointScores = keypoints.reduce(
    (sum, { part, position, score }) => (isPartWithinRadius(existingPoses, position, part) ? sum : sum + score),
    0.0,
  );
  return notOverlappedKeypointScores / keypoints.length;
}
2021-05-05 16:07:44 +02:00
/**
 * Decodes up to `maxDetected` poses from the model output buffers.
 *
 * Root candidates are taken from a max-heap in decreasing score order. A candidate is
 * skipped when `withinRadius` reports it as overlapping a previously accepted pose.
 * Otherwise a pose is decoded from it, keypoints scoring at or below `minConfidence`
 * are dropped, and the pose is kept if its instance score exceeds `minConfidence`.
 *
 * @param offsets offset buffer
 * @param scores heatmap score buffer
 * @param displacementsFwd forward displacement buffer
 * @param displacementsBwd backward displacement buffer
 * @param maxDetected upper bound on the number of returned poses
 * @param minConfidence threshold applied to root candidates, keypoints and the pose score
 * @returns accepted poses as `{ keypoints, box, score }`, with score rounded to two decimals
 */
export function decode(offsets, scores, displacementsFwd, displacementsBwd, maxDetected, minConfidence) {
  const poses: { keypoints, box: Box, score: number }[] = [];
  const candidates = buildPartWithScoreQueue(minConfidence, scores);
  for (;;) {
    // stop once enough poses were collected or no candidates remain
    if (!(poses.length < maxDetected) || candidates.empty()) break;

    // the top element of the heap is the next root candidate
    const root = candidates.dequeue();
    // @ts-ignore this one is tree walk
    const rootImageCoords = utils.getImageCoords(root.part, outputStride, offsets);
    // part-based non-maximum suppression against previously detected instances
    // @ts-ignore this one is tree walk
    const suppressed = withinRadius(poses, rootImageCoords, root.part.id);
    if (suppressed) continue;

    // otherwise start a new detection instance at the position of the root
    const confident = decodePose(root, scores, offsets, displacementsFwd, displacementsBwd)
      .filter((keypoint) => keypoint.score > minConfidence);
    const score = getInstanceScore(poses, confident);
    const box = utils.getBoundingBox(confident);
    if (score > minConfidence) {
      poses.push({ keypoints: confident, box, score: Math.round(100 * score) / 100 });
    }
  }
  return poses;
}
2021-09-28 18:01:48 +02:00
2022-10-17 02:28:57 +02:00
/**
 * Runs PoseNet on an input tensor and returns the detected bodies.
 *
 * posenet is mostly obsolete; caching is not implemented.
 *
 * Returns an empty list when the model is not loaded or does not report an input shape.
 * The shape is checked once, before any tensor work, so `decode()` is never called with
 * missing buffers.
 *
 * @param input image tensor; `input.shape[1]` and `input.shape[2]` are used to scale results back
 * @param config configuration; `config.body.maxDetected` and `config.body.minConfidence` are read
 * @returns poses produced by `decode()` and rescaled by `utils.scalePoses`
 */
export async function predict(input: Tensor4D, config: Config): Promise<BodyResult[]> {
  if (!model?.['executor']) return [];
  const inputShape = model.inputs[0].shape;
  if (!inputShape) return [];
  // NOTE(review): tf.image.resizeBilinear documents `size` as [newHeight, newWidth]; this keeps the
  // original [shape[2], shape[1]] order, which is only equivalent for square model inputs — confirm.
  const modelSize: [number, number] = [inputShape[2], inputShape[1]];
  const res = tf.tidy(() => {
    const resized = tf.image.resizeBilinear(input, modelSize);
    // x / 127.5 - 1; presumably maps 0..255 pixel values to -1..1 — verify against the caller
    const normalized = tf.sub(tf.div(tf.cast(resized, 'float32'), 127.5), 1.0);
    const results: Tensor[] = model.execute(normalized, poseNetOutputs) as Tensor[];
    const results3d = results.map((y) => tf.squeeze(y, [0])); // drop the leading axis 0
    results3d[1] = tf.sigmoid(results3d[1]); // apply sigmoid on scores
    return results3d;
  });
  // read the outputs into buffers, then release the tensors returned by tidy
  const buffers = await Promise.all(res.map((tensor: Tensor) => tensor.buffer()));
  for (const t of res) tf.dispose(t);
  // buffer order follows poseNetOutputs: offsets, scores, displacementsFwd, displacementsBwd
  const decoded = decode(buffers[0], buffers[1], buffers[2], buffers[3], config.body.maxDetected, config.body.minConfidence);
  const scaled = utils.scalePoses(decoded, [input.shape[1], input.shape[2]], modelSize);
  return scaled;
}
/**
 * Loads the PoseNet model, reusing the cached instance when possible.
 *
 * The model is (re)loaded from `config.body.modelPath` when none is cached or when
 * `env.initial` is set; otherwise the cached model is returned, with a log line in debug mode.
 *
 * @param config configuration; `config.body.modelPath` and `config.debug` are read
 * @returns the loaded or cached graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  const mustLoad = !model || env.initial;
  if (mustLoad) {
    model = await loadModel(config.body.modelPath);
    return model;
  }
  if (config.debug) log('cached model:', model['modelUrl']);
  return model;
}