2021-05-25 14:58:20 +02:00
/**
 * FaceRes model implementation
 *
 * Returns age, gender, and descriptor
 * Implements face similarity function
 *
 * Based on: [**HSE-FaceRes**](https://github.com/HSE-asavchenko/HSE_FaceRec_tf)
 */
2021-10-22 22:09:52 +02:00
import { log , join , now } from '../util/util' ;
2021-03-21 19:18:51 +01:00
import * as tf from '../../dist/tfjs.esm.js' ;
2021-09-13 19:28:35 +02:00
import type { Tensor , GraphModel } from '../tfjs/types' ;
import type { Config } from '../config' ;
2021-09-27 19:58:13 +02:00
import { env } from '../util/env' ;
2021-03-21 19:18:51 +01:00
2021-09-17 17:23:00 +02:00
// Module-level state shared by the functions in this file.

/** Loaded description model; load() resets it to null when `env.initial` is set. */
let model: GraphModel | null;

/** Last result per face index, returned again by predict() while skipping is allowed. */
const last: {
  age: number;
  gender: string;
  genderScore: number;
  descriptor: number[];
}[] = [];

/** Value of now() recorded right after the last model execution. */
let lastTime = 0;
/** Face count passed to the last predict() call; a different count bypasses the cache. */
let lastCount = 0;
/** Consecutive skipped runs; starts at the maximum so the very first call is never skipped. */
let skipped = Number.MAX_SAFE_INTEGER;
2021-06-03 15:41:53 +02:00
/**
 * Loads the face description model, reusing the cached instance when available.
 *
 * The cache is discarded first when `env.initial` is set.
 *
 * @param config global configuration; reads `modelBasePath`, `face.description.modelPath` and `debug`
 * @returns the loaded (or cached) graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  const modelUrl = join(config.modelBasePath, config.face.description?.modelPath || '');
  if (env.initial) model = null;

  // fast path: model already loaded earlier
  if (model) {
    if (config.debug) log('cached model:', modelUrl);
    return model;
  }

  model = await tf.loadGraphModel(modelUrl) as unknown as GraphModel;
  if (!model) {
    log('load model failed:', config.face.description?.modelPath || '');
  } else if (config.debug) {
    log('load model:', modelUrl);
  }
  return model;
}
/**
 * Prepares a face tensor for the description model.
 *
 * Resizes the tensor to the model's input resolution and rescales values by 255
 * (the input received from the detector is expected to be normalized to 0..1 and
 * already straightened).
 *
 * @param input a tensor, or an object holding one in its `image` or `tensor` property
 * @returns a new resized and rescaled tensor; when the model is not loaded or its input
 *   shape is unknown, the input tensor itself is returned unchanged
 */
export function enhance(input): Tensor {
  const tensor = (input.image || input.tensor || input) as Tensor;
  const inputShape = model?.inputs[0].shape;
  if (!inputShape) return tensor; // model has no shape so no point continuing
  // The resized tensor is fed directly to the model, so the model input is laid out like the
  // resizeBilinear output: [batch, height, width, channels]. resizeBilinear takes [newHeight, newWidth].
  const crop = tf.image.resizeBilinear(tensor, [inputShape[1], inputShape[2]], false);
  const norm = tf.mul(crop, 255);
  tf.dispose(crop);
  return norm;
  // Previously tried and disabled preprocessing variants:
  // - tight crop via tf.image.cropAndResize with empirical box [0.05, 0.15, 0.85, 0.85] (top, left, bottom, right)
  // - grayscale conversion with rgb2gray factors [0.2989, 0.5870, 0.1140] to avoid colorization impact
  //   (https://www.mathworks.com/help/matlab/ref/rgb2gray.html)
}
2021-11-13 18:23:32 +01:00
/**
 * Refines the winning age bin using its stronger neighbouring bin.
 *
 * Keeps the original interpolation formula, but neighbours outside the distribution
 * (before the first or after the last bin) count as 0, so edge bins no longer yield NaN.
 *
 * @param distribution per-bin scores from the age output tensor
 * @param bin index of the highest-scoring bin
 * @returns refined age rounded to one decimal place
 */
function interpolateAge(distribution: ArrayLike<number>, bin: number): number {
  const below = distribution[bin - 1] ?? 0;
  const above = distribution[bin + 1] ?? 0;
  const scaled = below > above ? 10 * bin - 100 * below : 10 * bin + 100 * above;
  return Math.round(scaled) / 10;
}

/**
 * Runs the description model on a face tensor and returns age, gender and descriptor.
 *
 * Results are cached per face index. When `config.skipAllowed` is set, the frame and time
 * budgets from `config.face.description` are not exhausted, the face count is unchanged and
 * the cached age is positive, the cached result for `idx` is returned without running the model.
 *
 * @param image face tensor passed to enhance(); this function does not dispose it
 * @param config global configuration; reads `skipAllowed` and `face.description`
 * @param idx face index used as the cache slot
 * @param count face count of the current call; a change bypasses the cache
 * @returns age, gender ('female' | 'male' | 'unknown'), gender score and descriptor values
 */
export async function predict(image: Tensor, config: Config, idx, count): Promise<{ age: number, gender: string, genderScore: number, descriptor: number[] }> {
  if (!model) return { age: 0, gender: 'unknown', genderScore: 0, descriptor: [] };
  const skipFrame = skipped < (config.face.description?.skipFrames || 0);
  const skipTime = (config.face.description?.skipTime || 0) > (now() - lastTime);
  const cached = last[idx];
  if (config.skipAllowed && skipFrame && skipTime && (lastCount === count) && cached && cached.age > 0) {
    skipped++;
    return cached;
  }
  skipped = 0;

  const obj = { age: 0, gender: 'unknown', genderScore: 0, descriptor: [] as number[] };
  if (config.face.description?.enabled) {
    const enhanced = enhance(image);
    const executed = model.execute(enhanced) as Tensor | Tensor[];
    lastTime = now();
    // enhance() returns its input unchanged when the model shape is unknown: never dispose the caller's tensor
    if (enhanced !== image) tf.dispose(enhanced);
    const resT = Array.isArray(executed) ? executed : [executed];
    try {
      const genderT = resT.find((t) => t.shape[1] === 1);
      if (genderT) {
        const gender = await genderT.data();
        const confidence = Math.trunc(200 * Math.abs(gender[0] - 0.5)) / 100;
        if (confidence > (config.face.description?.minConfidence || 0)) {
          obj.gender = gender[0] <= 0.5 ? 'female' : 'male';
          obj.genderScore = Math.min(0.99, confidence);
        }
      }

      const ageT = resT.find((t) => t.shape[1] === 100);
      if (ageT) {
        const argmax = tf.argMax(ageT, 1);
        const bin = (await argmax.data())[0];
        tf.dispose(argmax);
        obj.age = interpolateAge(await ageT.data(), bin);
      }

      // the 1024-element descriptor could be reduced to 128 elements via reshape([128, 8]) + logSumExp(1); not applied
      const desc = resT.find((t) => t.shape[1] === 1024);
      if (desc) obj.descriptor = Array.from(await desc.data());
    } finally {
      // release model outputs even if reading one of them fails
      resT.forEach((t) => tf.dispose(t));
    }
  }

  last[idx] = obj;
  lastCount = count;
  return obj;
}