2021-04-24 22:04:49 +02:00
import * as utils from './utils' ;
import * as kpt from './keypoints' ;
// Distance between adjacent heatmap cells in image coordinates: grid index = round(coord / outputStride),
// and image coord = index * outputStride.
const outputStride = 16;
// Half-width (in heatmap cells) of the square window a score must dominate to count as a local maximum.
const localMaximumRadius = 1;
// Part-based NMS radius of 50, stored squared so it can be compared against squared distances directly.
const squaredNmsRadius = 50 * 50;
2021-04-24 22:04:49 +02:00
2021-05-05 16:07:44 +02:00
/**
 * Walk one edge of the pose graph. Starting from an already-decoded keypoint, estimate the position of part
 * `targetId` by following the displacement field. Then refine that estimate `offsetRefineStep` times: each step
 * snaps to the nearest heatmap cell and adds that cell's offset.
 *
 * NOTE(review): `scores`, `offsets` and `displacements` are assumed to be TensorBuffer-like objects exposing
 * `.shape` and `.get(y, x, channel)`. Confirm against the caller.
 *
 * @param edgeId - edge index; selects displacement channel `edgeId` (y) and `numEdges + edgeId` (x)
 * @param sourceKeypoint - decoded keypoint whose `position` ({ x, y }, image coordinates) is the walk's origin
 * @param targetId - part id being decoded (channel in `scores`/`offsets`, index into `kpt.partNames`)
 * @param scores - heatmap buffer; `shape[0]` is height and `shape[1]` is width
 * @param offsets - offset buffer, read through `utils.getOffsetPoint`
 * @param displacements - displacement buffer whose channels hold all y displacements followed by all x displacements
 * @param offsetRefineStep - number of snap-and-offset refinement iterations
 * @returns `{ position, part, score }` for the target part
 */
function traverse(edgeId, sourceKeypoint, targetId, scores, offsets, displacements, offsetRefineStep = 2) {
  const [height, width] = scores.shape;
  const numEdges = displacements.shape[2] / 2;

  // Snap an image-space point to the nearest cell of the strided grid, clamped to the heatmap bounds.
  const toGridCell = (point) => {
    const row = utils.clamp(Math.round(point.y / outputStride), 0, height - 1);
    const col = utils.clamp(Math.round(point.x / outputStride), 0, width - 1);
    return { y: row, x: col };
  };

  // Nearest-neighbour lookup of the source->target displacement at the source keypoint's cell.
  const sourceCell = toGridCell(sourceKeypoint.position);
  let target = utils.addVectors(sourceKeypoint.position, {
    y: displacements.get(sourceCell.y, sourceCell.x, edgeId),
    x: displacements.get(sourceCell.y, sourceCell.x, numEdges + edgeId),
  });

  // Re-anchor to the nearest grid cell and add that cell's offset for the target part.
  for (let step = 0; step < offsetRefineStep; step++) {
    const cell = toGridCell(target);
    const offset = utils.getOffsetPoint(cell.y, cell.x, targetId, offsets);
    target = utils.addVectors(
      { x: cell.x * outputStride, y: cell.y * outputStride },
      { x: offset.x, y: offset.y },
    );
  }

  // The score is read from the heatmap cell nearest the final refined position.
  const finalCell = toGridCell(target);
  return {
    position: target,
    part: kpt.partNames[targetId],
    score: scores.get(finalCell.y, finalCell.x, targetId),
  };
}
2021-05-05 02:46:33 +02:00
/**
 * Decode a full set of keypoints for one pose instance, growing outwards from a root part along `kpt.poseChain`.
 *
 * First pass: walks the chain from its last edge to its first, going child -> parent with `displacementsBwd`.
 * Second pass: walks it first to last, going parent -> child with `displacementsFwd`.
 * Only parts not already decoded are filled in.
 *
 * @param root - root candidate `{ score, part: { heatmapY, heatmapX, id } }` as produced by `buildPartWithScoreQueue`
 * @param scores - heatmap buffer; `shape[2]` is the number of parts
 * @param offsets - offset buffer, forwarded to `utils.getImageCoords` and `traverse`
 * @param displacementsFwd - parent -> child displacement buffer
 * @param displacementsBwd - child -> parent displacement buffer
 * @returns array indexed by part id of `{ score, part, position }`; parts that were never reached stay as holes
 */
export function decodePose(root, scores, offsets, displacementsFwd, displacementsBwd) {
  // One { parent, child } pair of part ids per edge of the pose chain, in chain order.
  const edges = kpt.poseChain.map(([parentName, childName]) => ({
    parent: kpt.partIds[parentName],
    child: kpt.partIds[childName],
  }));
  const keypoints = new Array(scores.shape[2]);

  // Seed the instance with the root part at its image-space position.
  const rootId = root.part.id;
  const rootPosition = utils.getImageCoords(root.part, outputStride, offsets);
  keypoints[rootId] = { score: root.score, part: kpt.partNames[rootId], position: rootPosition };

  // Decode part `toId` from part `fromId` along `edge`, only when the source exists and the target is still missing.
  const extend = (edge, fromId, toId, displacements) => {
    if (!keypoints[fromId] || keypoints[toId]) return;
    keypoints[toId] = traverse(edge, keypoints[fromId], toId, scores, offsets, displacements);
  };

  for (let edge = edges.length - 1; edge >= 0; edge--) {
    extend(edge, edges[edge].child, edges[edge].parent, displacementsBwd);
  }
  for (let edge = 0; edge < edges.length; edge++) {
    extend(edge, edges[edge].parent, edges[edge].child, displacementsFwd);
  }
  return keypoints;
}
/**
 * Check whether `score` is a local maximum for part `keypointId`.
 *
 * The window spans `localMaximumRadius` cells in each direction around (heatmapY, heatmapX), clipped to the
 * heatmap bounds. Only a strictly greater neighbour disqualifies the cell, so ties still count as a maximum.
 *
 * @param keypointId - heatmap channel to inspect
 * @param score - score of the candidate cell
 * @param heatmapY - candidate row
 * @param heatmapX - candidate column
 * @param scores - heatmap buffer with `shape[0]` = height, `shape[1]` = width and `.get(y, x, channel)`
 * @returns true when no cell in the window exceeds `score`
 */
function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, scores) {
  const [height, width] = scores.shape;
  const top = Math.max(heatmapY - localMaximumRadius, 0);
  const bottom = Math.min(heatmapY + localMaximumRadius + 1, height);
  const left = Math.max(heatmapX - localMaximumRadius, 0);
  const right = Math.min(heatmapX + localMaximumRadius + 1, width);
  for (let y = top; y < bottom; y++) {
    for (let x = left; x < right; x++) {
      // A single strictly larger neighbour is enough to reject the candidate.
      if (scores.get(y, x, keypointId) > score) return false;
    }
  }
  return true;
}
2021-04-25 19:16:04 +02:00
/**
 * Collect root candidates for pose decoding into a max-heap ordered by score.
 *
 * Candidates are cells whose part score is not below `minConfidence` and is a local maximum of that part's heatmap.
 * Cells are scanned row by row, then column, then part id.
 *
 * @param minConfidence - minimum score for a cell to be considered
 * @param scores - heatmap buffer of shape [height, width, numKeypoints] with `.get(y, x, part)`
 * @returns `utils.MaxHeap` of `{ score, part: { heatmapY, heatmapX, id } }` keyed by `score`
 */
export function buildPartWithScoreQueue(minConfidence, scores) {
  const [height, width, numKeypoints] = scores.shape;
  const capacity = height * width * numKeypoints;
  const queue = new utils.MaxHeap(capacity, ({ score }) => score);
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      for (let id = 0; id < numKeypoints; id++) {
        const score = scores.get(y, x, id);
        // Below threshold: never a root candidate.
        if (score < minConfidence) continue;
        // Not a local peak of its own heatmap: skip.
        if (!scoreIsMaximumInLocalWindow(id, score, y, x, scores)) continue;
        queue.enqueue({ score, part: { heatmapY: y, heatmapX: x, id } });
      }
    }
  }
  return queue;
}
2021-04-25 19:16:04 +02:00
/**
 * Part-based NMS test: is `{ x, y }` within the NMS radius of part `keypointId` in any already-accepted pose?
 *
 * The keypoint lists stored in `poses` by `decode` are filtered by confidence, so array position no longer equals
 * part id. The matching keypoint is therefore located by its `part` name rather than by index.
 *
 * @param poses - accepted poses, each with a `keypoints` array of `{ part, position }` entries (may be sparse)
 * @param point - candidate position `{ x, y }` in image coordinates
 * @param keypointId - numeric part id; mapped to its name via `kpt.partNames`
 * @returns true if some pose has that part within `squaredNmsRadius` (squared distance, inclusive)
 */
function withinRadius(poses, { x, y }, keypointId) {
  const partName = kpt.partNames[keypointId];
  return poses.some(({ keypoints }) => {
    // `find` visits holes as undefined, hence the optional chaining on `kp`.
    const correspondingKeypoint = keypoints.find((kp) => kp?.part === partName)?.position;
    if (!correspondingKeypoint) return false;
    return utils.squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius;
  });
}
2021-05-05 16:07:44 +02:00
/**
 * Score a decoded pose instance for acceptance.
 *
 * The score is the mean keypoint score, where keypoints that fall within the NMS radius of the same part in an
 * already-accepted pose contribute 0. The divisor is the number of keypoints passed in.
 *
 * `keypoints` is the confidence-filtered list built in `decode`, so an entry's array index is not its part id.
 * The part id is recovered from each keypoint's `part` name via `kpt.partIds`.
 *
 * @param existingPoses - poses accepted so far (forwarded to `withinRadius`)
 * @param keypoints - `{ position, score, part }` entries for the candidate pose
 * @returns mean non-overlapping score, or 0 for an empty keypoint list (previously NaN from 0 / 0)
 */
function getInstanceScore(existingPoses, keypoints) {
  if (keypoints.length === 0) return 0;
  const notOverlappedKeypointScores = keypoints.reduce((result, { position, score, part }) => {
    if (!withinRadius(existingPoses, position, kpt.partIds[part])) result += score;
    return result;
  }, 0.0);
  return notOverlappedKeypointScores / keypoints.length;
}
2021-05-05 16:07:44 +02:00
/**
 * Multi-pose decoding. Repeatedly takes the highest-scoring root candidate and applies part-based NMS against
 * poses already accepted. Survivors are decoded into full poses, low-confidence keypoints are dropped, and the
 * instance is kept if its score clears the threshold.
 *
 * NOTE(review): buffers are assumed to be TensorBuffer-like (`.shape`, `.get(y, x, c)`); confirm against the caller.
 *
 * @param offsets - offset buffer
 * @param scores - heatmap buffer
 * @param displacementsFwd - parent -> child displacement buffer
 * @param displacementsBwd - child -> parent displacement buffer
 * @param maxDetected - maximum number of poses to return
 * @param minConfidence - threshold used for root candidates, per-keypoint filtering and the instance score
 * @returns accepted poses `{ keypoints, box, score }`, with `score` rounded to two decimals
 */
export function decode(offsets, scores, displacementsFwd, displacementsBwd, maxDetected, minConfidence) {
  const poses: Array<{ keypoints, box: [number, number, number, number], score: number }> = [];
  const queue = buildPartWithScoreQueue(minConfidence, scores);
  // Generate at most maxDetected object instances per image in decreasing root part score order.
  while (poses.length < maxDetected && !queue.empty()) {
    // The top element in the queue is the next root candidate.
    const root = queue.dequeue();
    // Part-based non-maximum suppression: reject a root candidate within `nmsRadius` of the same part
    // in a previously detected instance.
    const rootImageCoords = utils.getImageCoords(root.part, outputStride, offsets);
    if (withinRadius(poses, rootImageCoords, root.part.id)) continue;
    // Else decode a new instance from this root and keep only its confident keypoints.
    const keypoints = decodePose(root, scores, offsets, displacementsFwd, displacementsBwd)
      .filter((kp) => kp.score > minConfidence);
    // Nothing survived the per-keypoint threshold: there is no instance to score or box.
    if (keypoints.length === 0) continue;
    const score = getInstanceScore(poses, keypoints);
    // Written as `score > minConfidence` (not `<=` + continue) so a NaN score is still rejected.
    if (score > minConfidence) {
      // Bounding box is only computed for instances that are actually kept.
      const box = utils.getBoundingBox(keypoints);
      poses.push({ keypoints, box, score: Math.round(100 * score) / 100 });
    }
  }
  return poses;
}