2020-11-10 02:13:38 +01:00
import * as buildParts from './buildParts' ;
import * as decodePose from './decodePose' ;
import * as vectors from './vectors' ;
2020-10-12 01:22:43 +02:00
2020-12-17 00:36:24 +01:00
// Radius handed to buildParts.buildPartWithScoreQueue when selecting root candidates.
// Judging by the name, a part must be a local score maximum within this neighborhood
// (units presumably score-map cells — confirm in buildParts).
const kLocalMaximumRadius = 1 ;
2020-10-12 01:22:43 +02:00
/**
 * Tests whether a point lies within the (squared) NMS radius of the same keypoint
 * on any of the given poses.
 *
 * @param poses - array of poses; each must expose `keypoints[keypointId].position`
 * @param squaredNmsRadius - suppression radius, already squared
 * @param point - candidate location as `{ x, y }`
 * @param keypointId - index of the keypoint to compare against in every pose
 * @returns {boolean} true as soon as one pose's keypoint is within range (boundary inclusive)
 */
function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, { x, y }, keypointId) {
  for (const pose of poses) {
    const { x: otherX, y: otherY } = pose.keypoints[keypointId].position;
    const isClose = vectors.squaredDistance(y, x, otherY, otherX) <= squaredNmsRadius;
    if (isClose) return true;
  }
  return false;
}
2020-12-17 00:36:24 +01:00
2020-10-12 01:22:43 +02:00
/**
 * Scores a candidate pose as the mean of its keypoint scores, counting only keypoints
 * that are NOT within the NMS radius of the same keypoint on an already-accepted pose.
 *
 * Suppressed keypoints contribute 0 but still count toward the divisor.
 *
 * @param existingPoses - poses accepted so far
 * @param squaredNmsRadius - suppression radius, already squared
 * @param instanceKeypoints - array of `{ position, score }` for the candidate pose
 * @returns {number} mean non-overlapped score (NaN when instanceKeypoints is empty)
 */
function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
  let total = 0.0;
  instanceKeypoints.forEach(({ position, score }, keypointId) => {
    const overlapped = withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId);
    if (!overlapped) total += score;
  });
  return total / instanceKeypoints.length;
}
2020-12-17 00:36:24 +01:00
/**
 * Decodes up to `config.body.maxDetections` poses from the model output buffers.
 *
 * Root candidates are taken, in dequeue order, from the queue built by
 * `buildParts.buildPartWithScoreQueue` (using `config.body.scoreThreshold` and
 * `kLocalMaximumRadius`). A candidate is skipped when its image position lies within
 * `config.body.nmsRadius` of the same keypoint of an already-accepted pose (part-based
 * non-maximum suppression). Each surviving candidate is expanded into a keypoint set by
 * `decodePose.decodePose`, scored by `getInstanceScore`, and kept only if that score
 * exceeds `config.body.scoreThreshold`.
 *
 * @param scoresBuffer - part score buffer; passed to the queue builder and to decodePose
 * @param offsetsBuffer - offsets buffer; used to map parts to image coordinates
 * @param displacementsFwdBuffer - forward displacement buffer; passed to decodePose
 * @param displacementsBwdBuffer - backward displacement buffer; passed to decodePose
 * @param config - reads `body.scoreThreshold`, `body.nmsRadius`, `body.maxDetections`
 *   and `body.outputStride`
 * @returns {Array<{ keypoints: Array, score: number }>} accepted poses, in acceptance order
 */
function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config) {
  const { scoreThreshold, nmsRadius, maxDetections, outputStride } = config.body;
  const poses = [];
  const queue = buildParts.buildPartWithScoreQueue(scoreThreshold, kLocalMaximumRadius, scoresBuffer);
  // Distances are compared squared to avoid a sqrt per check. Note that `^` is bitwise
  // XOR in JavaScript (20 ^ 2 === 22), so the radius must be squared by multiplication.
  const squaredNmsRadius = nmsRadius * nmsRadius;
  // Accept at most maxDetections poses, consuming root candidates in queue order
  // (presumably highest root score first — the ordering is defined by buildParts).
  while (poses.length < maxDetections && !queue.empty()) {
    const root = queue.dequeue();
    // Part-based NMS: reject a root that falls within nmsRadius of the corresponding
    // part of a previously accepted pose.
    const rootImageCoords = vectors.getImageCoords(root.part, outputStride, offsetsBuffer);
    if (withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootImageCoords, root.part.id)) continue;
    // Otherwise decode a new pose instance starting from this root.
    const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, outputStride, displacementsFwdBuffer, displacementsBwdBuffer);
    const score = getInstanceScore(poses, squaredNmsRadius, keypoints);
    if (score > scoreThreshold) poses.push({ keypoints, score });
  }
  return poses;
}
// NOTE(review): this module imports with ES `import` syntax but exports via CommonJS
// `exports`; presumably the bundler reconciles the two — consider `export { decodeMultiplePoses };`.
exports . decodeMultiplePoses = decodeMultiplePoses ;