2021-05-25 14:58:20 +02:00
/**
 * HSE - FaceRes Module
 * Returns Age, Gender, Descriptor
 * Implements Face similarity function
 */
2021-04-09 14:07:58 +02:00
import { log , join } from '../helpers' ;
2021-03-21 19:18:51 +01:00
import * as tf from '../../dist/tfjs.esm.js' ;
2021-05-23 03:47:59 +02:00
import { Tensor , GraphModel } from '../tfjs/types' ;
2021-06-03 15:41:53 +02:00
import { Config } from '../config' ;
2021-03-21 19:18:51 +01:00
2021-05-23 03:47:59 +02:00
/** Description model; assigned by load() and reused by enhance() and predict(). */
let model: GraphModel;

/** Most recent prediction per face index; predict() returns it again while frames are skipped. */
const last: {
  age: number,
  gender: string,
  genderScore: number,
  descriptor: number[],
}[] = [];

/** Face count passed to the previous predict() run; a cached result is only reused when it matches. */
let lastCount = 0;

/** Frames skipped since the model last ran; starts at the maximum so the first predict() call does not reuse a cached result. */
let skipped = Number.MAX_SAFE_INTEGER;

/** Face database searched by match(). */
type DB = { name: string, source: string, embedding: number[] }[];
2021-06-03 15:41:53 +02:00
/**
 * Loads the description model, or returns the already loaded instance.
 *
 * The model location is `config.modelBasePath` joined with `config.face.description.modelPath`.
 * The loaded model is kept in module state, so later calls return it without reloading.
 *
 * @param config library configuration
 * @returns the loaded (or cached) graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  const modelUrl = join(config.modelBasePath, config.face.description.modelPath);
  if (model) {
    if (config.debug) log('cached model:', modelUrl);
    return model;
  }
  // @ts-ignore type mismatch for GraphModel
  model = await tf.loadGraphModel(modelUrl);
  if (!model) log('load model failed:', config.face.description.modelPath);
  else if (config.debug) log('load model:', modelUrl);
  return model;
}
2021-06-03 15:41:53 +02:00
/**
 * Scores how similar two face descriptors are.
 *
 * Computes the Minkowski distance of the given `order` between the two vectors (order 2 is the
 * Euclidean distance) and multiplies it by 5. The result is mapped onto 0..1 as
 * `max(0, 100 - distance) / 100`, so identical vectors score 1 and a scaled distance of 100
 * or more scores 0.
 *
 * @param embedding1 first descriptor
 * @param embedding2 second descriptor; must have the same length as `embedding1`
 * @param order Minkowski exponent (defaults to 2, Euclidean)
 * @returns similarity in 0..1, or 0 when a descriptor is missing or empty, or the lengths differ
 */
export function similarity(embedding1: Array<number>, embedding2: Array<number>, order = 2): number {
  const unusable = !embedding1
    || !embedding2
    || embedding1.length === 0
    || embedding1.length !== embedding2.length;
  if (unusable) return 0;
  // sum of |a_i - b_i|^order over all elements
  const total = embedding1.reduce((acc, value, i) => acc + Math.abs(value - embedding2[i]) ** order, 0);
  const distance = 5.0 * (total ** (1 / order));
  return Math.max(0, 100 - distance) / 100.0;
}
/**
 * Finds the database entry whose embedding is most similar to `embedding`.
 *
 * Entries without an embedding or a name are ignored. An entry is only accepted when its
 * similarity is strictly above `threshold`; on ties the earliest entry wins.
 *
 * @param embedding descriptor to look up
 * @param db known faces to search
 * @param threshold exclusive lower bound on similarity for an entry to be returned
 * @returns the best entry extended with its `similarity`, or an empty placeholder with similarity 0
 */
export function match(embedding: Array<number>, db: DB, threshold = 0) {
  const none = { similarity: 0, name: '', source: '', embedding: [] as number[] };
  // Array.isArray also rejects null / undefined inputs
  if (!Array.isArray(embedding) || !Array.isArray(db)) return none;
  let best = none;
  for (const entry of db) {
    if (!entry.embedding || !entry.name) continue;
    const score = similarity(embedding, entry.embedding);
    const improves = score > threshold && score > best.similarity;
    if (improves) best = { ...entry, similarity: score };
  }
  return best;
}
/**
 * Prepares a face image for the description model.
 *
 * Takes a fixed, tight crop of the face, resizes it to the model input resolution and multiplies
 * the pixel values by 255. Intermediate tensors are released by `tf.tidy`.
 *
 * @param input a tensor, or an object holding one in its `image` or `tensor` property
 * @returns the prepared batched tensor, or `null` when no tensor is found, the model is not
 *   loaded, or the model does not declare an input shape
 */
export function enhance(input): Tensor {
  const image = tf.tidy(() => {
    // input received from detector is already normalized to 0..1
    // input is also assumed to be straightened
    const tensor = input.image || input.tensor || input;
    if (!(tensor instanceof tf.Tensor)) return null;

    // enhance() is exported: return null instead of throwing when called before load()
    const shape = model?.inputs?.[0]?.shape;
    if (!shape) return null; // model has no shape so no point continuing

    // cropAndResize takes [cropHeight, cropWidth]; the NHWC crop is fed to the model as-is,
    // so height is shape[1] and width is shape[2] (previously swapped, which only worked for square inputs)
    const cropSize: [number, number] = [shape[1], shape[2]];

    // do a tight crop of image and resize it to fit the model
    const box = [[0.05, 0.15, 0.85, 0.85]]; // empirical values for top, left, bottom, right
    // const box = [[0.0, 0.0, 1.0, 1.0]]; // basically no crop for test
    const batched = (tensor.shape.length === 3) ? tf.expandDims(tensor, 0) : tensor; // add batch dimension if missing
    const crop = tf.image.cropAndResize(batched, box, [0], cropSize);

    // alternatives that were tried and are currently disabled:
    // - just resize to fit the embedding model instead of cropping:
    //     const crop = tf.image.resizeBilinear(tensor, cropSize, false);
    // - convert to black&white to avoid colorization impact
    //   (factors from https://www.mathworks.com/help/matlab/ref/rgb2gray.html):
    //     const rgb = [0.2989, 0.5870, 0.1140];
    //     const [red, green, blue] = tf.split(crop, 3, 3);
    //     const grayscale = tf.addN([tf.mul(red, rgb[0]), tf.mul(green, rgb[1]), tf.mul(blue, rgb[2])]);
    //     const merge = tf.stack([grayscale, grayscale, grayscale], 3).squeeze(4);
    // - increase image pseudo-contrast 100% (or per channel, or based on a histogram):
    //     const mean = merge.mean();
    //     const factor = 2;
    //     const contrast = merge.sub(mean).mul(factor).add(mean);
    // - normalize brightness to 0..1 as a crude pseudo-hdr of the image:
    //     const darken = crop.sub(crop.min());
    //     const lighten = darken.div(darken.max());

    // input is expected in 0..1 (see above), so this rescales it to 0..255
    return tf.mul(crop, 255);
  });
  return image;
}
2021-06-03 15:41:53 +02:00
/**
 * Runs the description model on a face and returns estimated age, gender and a face descriptor.
 *
 * The result for each face index is cached. The cached result for `idx` is returned without
 * running the model when all of the following hold:
 * - frame skipping is enabled (`config.skipFrame`);
 * - fewer than `config.face.description.skipFrames` frames have been skipped;
 * - `count` is unchanged since the last run;
 * - the cached result has a positive age.
 *
 * Gender is only set when its confidence exceeds `config.face.description.minConfidence`;
 * otherwise it stays `'unknown'`. If `config.face.description.enabled` is false, the model is
 * not run and the default result is cached and returned.
 *
 * @param image face image, prepared with `enhance` before inference
 * @param config library configuration
 * @param idx index of the face in the current frame, used as the cache key
 * @param count number of faces in the current frame
 * @returns the prediction (fresh or cached), or `null` when the model is not loaded; the promise
 *   rejects if inference throws, instead of never settling
 */
export async function predict(image: Tensor, config: Config, idx: number, count: number) {
  if (!model) return null;
  const cached = last[idx];
  if ((skipped < config.face.description.skipFrames) && config.skipFrame && (lastCount === count) && cached?.age && (cached.age > 0)) {
    skipped++;
    return cached;
  }
  skipped = 0;

  const obj: { age: number, gender: string, genderScore: number, descriptor: number[] } = {
    age: 0,
    gender: 'unknown',
    genderScore: 0,
    descriptor: [],
  };

  // enhance() returns null when it cannot prepare the input; never run the model on null
  const enhanced = enhance(image);
  let resT: Tensor[] | undefined;
  try {
    // the outputs are told apart below by their second dimension (1, 100 and 1024), so an array of tensors is expected
    if (enhanced && config.face.description.enabled) resT = model.predict(enhanced) as Tensor[];
  } finally {
    tf.dispose(enhanced);
  }

  if (resT) {
    const outputs = resT; // const binding so the closure below keeps the narrowed type
    try {
      tf.tidy(() => {
        const gender = outputs.find((t) => t.shape[1] === 1)?.dataSync(); // inside tf.tidy
        if (gender) {
          const confidence = Math.trunc(200 * Math.abs((gender[0] - 0.5))) / 100;
          if (confidence > config.face.description.minConfidence) {
            obj.gender = gender[0] <= 0.5 ? 'female' : 'male';
            obj.genderScore = Math.min(0.99, confidence);
          }
        }

        const ageT = outputs.find((t) => t.shape[1] === 100);
        if (ageT) {
          const age = tf.argMax(ageT, 1).dataSync()[0]; // inside tf.tidy
          const all = ageT.dataSync(); // inside tf.tidy
          // refine the argmax bin toward its more probable neighbour; a neighbour outside the
          // 0..99 range counts as 0 (previously all[100] was undefined and produced NaN)
          const below = all[age - 1] ?? 0;
          const above = all[age + 1] ?? 0;
          obj.age = Math.round(below > above ? 10 * age - 100 * below : 10 * age + 100 * above) / 10;
        }

        const desc = outputs.find((t) => t.shape[1] === 1024);
        // const reshape = desc.reshape([128, 8]); // reshape large 1024-element descriptor to 128 x 8
        // const reduce = reshape.logSumExp(1); // reduce 2nd dimension by calculating logSumExp on it which leaves us with 128-element descriptor
        if (desc) obj.descriptor = Array.from(desc.dataSync()); // inside tf.tidy
      });
    } finally {
      tf.dispose(outputs);
    }
  }

  last[idx] = obj;
  lastCount = count;
  return obj;
}