2020-11-10 02:13:38 +01:00
import * as buildParts from './buildParts' ;
import * as decodePose from './decodePose' ;
import * as vectors from './vectors' ;
2020-10-12 01:22:43 +02:00
2020-12-17 00:36:24 +01:00
/**
 * Radius handed to `buildParts.buildPartWithScoreQueue` when collecting root candidates.
 * NOTE(review): presumably the neighbourhood (in heatmap cells) searched for local score
 * maxima — confirm against buildParts.
 */
const kLocalMaximumRadius = 1;

/**
 * Output stride used both when mapping a root part to image coordinates and when decoding
 * a full pose.
 * NOTE(review): presumably has to match the stride of the model that produced the buffers — confirm.
 */
const defaultOutputStride = 16;
2020-12-17 00:36:24 +01:00
2020-10-12 01:22:43 +02:00
/**
 * Tests whether a point lies close to the same keypoint of any previously accepted pose.
 *
 * @param poses - accepted poses; each must expose `keypoints[keypointId].position` with `x`/`y`
 * @param squaredNmsRadius - squared suppression radius (compared against squared distances)
 * @param point - candidate position `{ x, y }`
 * @param keypointId - index of the keypoint to compare against in every pose
 * @returns true as soon as one pose's corresponding keypoint is within the radius (inclusive)
 */
function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, { x, y }, keypointId) {
  const isNearby = (pose) => {
    const target = pose.keypoints[keypointId].position;
    // vectors.squaredDistance takes (y1, x1, y2, x2) — row-first ordering.
    const distanceSq = vectors.squaredDistance(y, x, target.y, target.x);
    return distanceSq <= squaredNmsRadius;
  };
  return poses.some(isNearby);
}
2020-12-17 00:36:24 +01:00
2020-10-12 01:22:43 +02:00
/**
 * Scores a candidate pose as an averaged keypoint score. A keypoint contributes 0 when
 * `withinNmsRadiusOfCorrespondingPoint` finds an existing pose with a nearby matching keypoint.
 *
 * @param existingPoses - poses accepted so far
 * @param squaredNmsRadius - squared suppression radius
 * @param instanceKeypoints - candidate keypoints; each provides `position` and `score`
 * @returns sum of non-suppressed scores divided by the total keypoint count
 *   (NaN when `instanceKeypoints` is empty)
 */
function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
  let unsuppressedTotal = 0.0;
  instanceKeypoints.forEach(({ position, score }, keypointId) => {
    // Keypoints already covered by an accepted pose do not add to this instance's score.
    const suppressed = withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId);
    if (!suppressed) unsuppressedTotal += score;
  });
  // Divide by the full count, not by the number of contributing keypoints.
  return unsuppressedTotal / instanceKeypoints.length;
}
2020-12-17 00:36:24 +01:00
2021-03-11 16:26:14 +01:00
/**
 * Decodes multiple poses using part-based non-maximum suppression.
 *
 * Root candidates come from `buildParts.buildPartWithScoreQueue`. A candidate is skipped when its
 * image coordinates fall within `nmsRadius` of the same part of an already accepted pose.
 * Otherwise the candidate is expanded into a full keypoint set with `decodePose.decodePose`.
 *
 * @param scoresBuffer - score buffer; used to build the candidate queue and to decode poses
 * @param offsetsBuffer - offsets buffer; used to map parts to image coordinates and to decode poses
 * @param displacementsFwdBuffer - forwarded to `decodePose.decodePose`
 * @param displacementsBwdBuffer - forwarded to `decodePose.decodePose`
 * @param nmsRadius - suppression radius (squared once, since comparisons use squared distances)
 * @param maxDetections - maximum number of poses returned
 * @param scoreThreshold - passed to the queue builder as its threshold, and also used as the
 *   exclusive lower bound an instance score must exceed for the pose to be kept
 * @returns accepted poses, in the order their roots were dequeued
 */
export function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, nmsRadius, maxDetections, scoreThreshold) {
  const poses: Array<{ keypoints: any, score: number }> = [];
  const queue = buildParts.buildPartWithScoreQueue(scoreThreshold, kLocalMaximumRadius, scoresBuffer);
  // Distances are compared squared to avoid a sqrt per check. `^` is bitwise XOR in JS/TS
  // (20 ^ 2 === 22), so the radius must be squared by multiplication.
  const squaredNmsRadius = nmsRadius * nmsRadius;

  // Generate at most maxDetections object instances per image in decreasing root part score order.
  while (poses.length < maxDetections && !queue.empty()) {
    // The top element in the queue is the next root candidate.
    const root = queue.dequeue();

    // Part-based non-maximum suppression: reject a root candidate that lies within a disk of
    // `nmsRadius` pixels from the corresponding part of a previously detected instance.
    const rootImageCoords = vectors.getImageCoords(root.part, defaultOutputStride, offsetsBuffer);
    if (withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootImageCoords, root.part.id)) continue;

    // Otherwise start a new detection instance at the position of the root.
    const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, defaultOutputStride, displacementsFwdBuffer, displacementsBwdBuffer);
    const score = getInstanceScore(poses, squaredNmsRadius, keypoints);
    if (score > scoreThreshold) poses.push({ keypoints, score });
  }
  return poses;
}