refactoring

pull/46/head
Vladimir Mandic 2021-03-19 18:46:36 -04:00
parent d85c913347
commit 1b68ca1160
41 changed files with 164 additions and 274 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
node_modules node_modules
pnpm-lock.yaml pnpm-lock.yaml
test

View File

@ -20,8 +20,8 @@ function str(json) {
function log(...txt) { function log(...txt) {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log(...txt); console.log(...txt);
// @ts-ignore const div = document.getElementById('log');
document.getElementById('log').innerHTML += `<br>${txt}`; if (div) div.innerHTML += `<br>${txt}`;
} }
// helper function to draw detected faces // helper function to draw detected faces

View File

@ -19,8 +19,8 @@ function str(json) {
function log(...txt) { function log(...txt) {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log(...txt); console.log(...txt);
// @ts-ignore const div = document.getElementById('log');
document.getElementById('log').innerHTML += `<br>${txt}`; if (div) div.innerHTML += `<br>${txt}`;
} }
// helper function to draw detected faces // helper function to draw detected faces

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

96
dist/face-api.esm.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

96
dist/face-api.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -22,7 +22,7 @@
}, },
"scripts": { "scripts": {
"start": "node --trace-warnings demo/node-singleprocess.js", "start": "node --trace-warnings demo/node-singleprocess.js",
"dev": "npm install && node server/serve.js", "dev": "node --trace-warnings server/serve.js",
"build": "rimraf dist/* types/* typedoc/* && node server/build.js", "build": "rimraf dist/* types/* typedoc/* && node server/build.js",
"lint": "eslint src/**/* demo/*.js server/*.js", "lint": "eslint src/**/* demo/*.js server/*.js",
"test": "eslint src/**/* demo/*.js server/*.js" "test": "eslint src/**/* demo/*.js server/*.js"
@ -38,7 +38,6 @@
"emotion-detection", "emotion-detection",
"face-recognition" "face-recognition"
], ],
"peerDependencies": {},
"devDependencies": { "devDependencies": {
"@tensorflow/tfjs": "^3.3.0", "@tensorflow/tfjs": "^3.3.0",
"@tensorflow/tfjs-backend-wasm": "^3.3.0", "@tensorflow/tfjs-backend-wasm": "^3.3.0",
@ -50,7 +49,7 @@
"@vladmandic/pilogger": "^0.2.15", "@vladmandic/pilogger": "^0.2.15",
"chokidar": "^3.5.1", "chokidar": "^3.5.1",
"dayjs": "^1.10.4", "dayjs": "^1.10.4",
"esbuild": "^0.9.3", "esbuild": "^0.9.5",
"eslint": "^7.22.0", "eslint": "^7.22.0",
"eslint-config-airbnb-base": "^14.2.1", "eslint-config-airbnb-base": "^14.2.1",
"eslint-plugin-import": "^2.22.1", "eslint-plugin-import": "^2.22.1",

View File

@ -209,7 +209,7 @@ async function build(f, msg, dev = false) {
for (const [targetGroupName, targetGroup] of Object.entries(targets)) { for (const [targetGroupName, targetGroup] of Object.entries(targets)) {
for (const [targetName, targetOptions] of Object.entries(targetGroup)) { for (const [targetName, targetOptions] of Object.entries(targetGroup)) {
// if triggered from watch mode, rebuild only browser bundle // if triggered from watch mode, rebuild only browser bundle
if ((require.main !== module) && (targetGroupName !== 'browserBundle')) continue; // if ((require.main !== module) && (targetGroupName !== 'browserBundle')) continue;
// @ts-ignore // @ts-ignore
const meta = await esbuild.build({ ...common, ...targetOptions }); const meta = await esbuild.build({ ...common, ...targetOptions });
const stats = await getStats(meta); const stats = await getStats(meta);

View File

@ -5,9 +5,7 @@ import { seperateWeightMaps } from '../faceProcessor/util';
import { TinyXception } from '../xception/TinyXception'; import { TinyXception } from '../xception/TinyXception';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { extractParamsFromWeightMap } from './extractParamsFromWeightMap'; import { extractParamsFromWeightMap } from './extractParamsFromWeightMap';
import { import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
AgeAndGenderPrediction, Gender, NetOutput, NetParams,
} from './types';
import { NeuralNetwork } from '../NeuralNetwork'; import { NeuralNetwork } from '../NeuralNetwork';
import { NetInput, TNetInput, toNetInput } from '../dom/index'; import { NetInput, TNetInput, toNetInput } from '../dom/index';

View File

@ -1,8 +1,6 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
} from '../common/index';
import { NetParams } from './types'; import { NetParams } from './types';
export function extractParamsFromWeightMap( export function extractParamsFromWeightMap(

View File

@ -107,9 +107,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
height += diff; height += diff;
} }
return new Box({ return new Box({ x, y, width, height });
x, y, width, height,
});
} }
public rescale(s: IDimensions | number): Box<BoxType> { public rescale(s: IDimensions | number): Box<BoxType> {
@ -136,9 +134,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
} }
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> { public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
const { const { x, y, right, bottom } = this;
x, y, right, bottom,
} = this;
const clippedX = Math.max(x, 0); const clippedX = Math.max(x, 0);
const clippedY = Math.max(y, 0); const clippedY = Math.max(y, 0);

View File

@ -3,9 +3,7 @@ import * as tf from '../../dist/tfjs.esm';
import { Dimensions } from '../classes/Dimensions'; import { Dimensions } from '../classes/Dimensions';
import { env } from '../env/index'; import { env } from '../env/index';
import { padToSquare } from '../ops/padToSquare'; import { padToSquare } from '../ops/padToSquare';
import { import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
computeReshapedDimensions, isTensor3D, isTensor4D, range,
} from '../utils/index';
import { createCanvasFromMedia } from './createCanvas'; import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare'; import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types'; import { TResolvedNetInput } from './types';
@ -23,10 +21,7 @@ export class NetInput {
private _inputSize: number private _inputSize: number
constructor( constructor(inputs: Array<TResolvedNetInput>, treatAsBatchInput: boolean = false) {
inputs: Array<TResolvedNetInput>,
treatAsBatchInput: boolean = false,
) {
if (!Array.isArray(inputs)) { if (!Array.isArray(inputs)) {
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`); throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
} }
@ -131,13 +126,11 @@ export class NetInput {
const input = this.getInput(batchIdx); const input = this.getInput(batchIdx);
if (input instanceof tf.Tensor) { if (input instanceof tf.Tensor) {
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'. let imgTensor = isTensor4D(input) ? input : tf.expandDims(input);
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>();
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
imgTensor = padToSquare(imgTensor, isCenterInputs); imgTensor = padToSquare(imgTensor, isCenterInputs);
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) { if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]); imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize], false, false);
} }
return imgTensor.as3D(inputSize, inputSize, 3); return imgTensor.as3D(inputSize, inputSize, 3);
@ -150,9 +143,7 @@ export class NetInput {
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`); throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
}); });
// const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3); const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
// const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3);
return batchTensor; return batchTensor;
}); });

View File

@ -14,10 +14,7 @@ import { TNetInput } from './types';
* @param detections The face detection results or face bounding boxes for that image. * @param detections The face detection results or face bounding boxes for that image.
* @returns The Canvases of the corresponding image region for each detected face. * @returns The Canvases of the corresponding image region for each detected face.
*/ */
export async function extractFaces( export async function extractFaces(input: TNetInput, detections: Array<FaceDetection | Rect>): Promise<HTMLCanvasElement[]> {
input: TNetInput,
detections: Array<FaceDetection | Rect>,
): Promise<HTMLCanvasElement[]> {
const { Canvas } = env.getEnv(); const { Canvas } = env.getEnv();
let canvas = input as HTMLCanvasElement; let canvas = input as HTMLCanvasElement;
@ -36,16 +33,13 @@ export async function extractFaces(
} }
const ctx = getContext2dOrThrow(canvas); const ctx = getContext2dOrThrow(canvas);
const boxes = detections.map( const boxes = detections
(det) => (det instanceof FaceDetection .map((det) => (det instanceof FaceDetection
? det.forSize(canvas.width, canvas.height).box.floor() ? det.forSize(canvas.width, canvas.height).box.floor()
: det), : det))
)
.map((box) => box.clipAtImageBorders(canvas.width, canvas.height)); .map((box) => box.clipAtImageBorders(canvas.width, canvas.height));
return boxes.map(({ return boxes.map(({ x, y, width, height }) => {
x, y, width, height,
}) => {
const faceImg = createCanvas({ width, height }); const faceImg = createCanvas({ width, height });
if (width > 0 && height > 0) getContext2dOrThrow(faceImg).putImageData(ctx.getImageData(x, y, width, height), 0, 0); if (width > 0 && height > 0) getContext2dOrThrow(faceImg).putImageData(ctx.getImageData(x, y, width, height), 0, 0);
return faceImg; return faceImg;

View File

@ -1,9 +1,7 @@
/* eslint-disable max-classes-per-file */ /* eslint-disable max-classes-per-file */
import { Box, IBoundingBox, IRect } from '../classes/index'; import { Box, IBoundingBox, IRect } from '../classes/index';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow'; import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField';
AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions,
} from './DrawTextField';
export interface IDrawBoxOptions { export interface IDrawBoxOptions {
boxColor?: string boxColor?: string

2
src/env/isNodejs.ts vendored
View File

@ -2,7 +2,5 @@ export function isNodejs(): boolean {
return typeof global === 'object' return typeof global === 'object'
&& typeof require === 'function' && typeof require === 'function'
&& typeof module !== 'undefined' && typeof module !== 'undefined'
// issues with gatsby.js: module.exports is undefined
// && !!module.exports
&& typeof process !== 'undefined' && !!process.version; && typeof process !== 'undefined' && !!process.version;
} }

View File

@ -1,9 +1,4 @@
import { import { extractConvParamsFactory, extractSeparableConvParamsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
extractConvParamsFactory,
extractSeparableConvParamsFactory,
ExtractWeightsFunction,
ParamMapping,
} from '../common/index';
import { DenseBlock3Params, DenseBlock4Params } from './types'; import { DenseBlock3Params, DenseBlock4Params } from './types';
export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) { export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {

View File

@ -2,11 +2,7 @@ import * as tf from '../../dist/tfjs.esm';
import { fullyConnectedLayer } from '../common/fullyConnectedLayer'; import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import { NetInput } from '../dom/index'; import { NetInput } from '../dom/index';
import { import { FaceFeatureExtractorParams, IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
FaceFeatureExtractorParams,
IFaceFeatureExtractor,
TinyFaceFeatureExtractorParams,
} from '../faceFeatureExtractor/types';
import { NeuralNetwork } from '../NeuralNetwork'; import { NeuralNetwork } from '../NeuralNetwork';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { extractParamsFromWeightMap } from './extractParamsFromWeightMap'; import { extractParamsFromWeightMap } from './extractParamsFromWeightMap';

View File

@ -1,8 +1,6 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
} from '../common/index';
import { NetParams } from './types'; import { NetParams } from './types';
export function extractParamsFromWeightMap( export function extractParamsFromWeightMap(

View File

@ -1,12 +1,8 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping,
} from '../common/index';
import { isFloat } from '../utils/index'; import { isFloat } from '../utils/index';
import { import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams,
} from './types';
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) { function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D { function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {

View File

@ -6,14 +6,8 @@ import { WithFaceLandmarks } from '../factories/WithFaceLandmarks';
import { ComposableTask } from './ComposableTask'; import { ComposableTask } from './ComposableTask';
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults'; import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
import { nets } from './nets'; import { nets } from './nets';
import { import { PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
PredictAllAgeAndGenderWithFaceAlignmentTask, import { PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
PredictSingleAgeAndGenderWithFaceAlignmentTask,
} from './PredictAgeAndGenderTask';
import {
PredictAllFaceExpressionsWithFaceAlignmentTask,
PredictSingleFaceExpressionsWithFaceAlignmentTask,
} from './PredictFaceExpressionsTask';
export class ComputeFaceDescriptorsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> { export class ComputeFaceDescriptorsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
constructor( constructor(

View File

@ -10,14 +10,8 @@ import { extendWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFac
import { ComposableTask } from './ComposableTask'; import { ComposableTask } from './ComposableTask';
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks'; import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
import { nets } from './nets'; import { nets } from './nets';
import { import { PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
PredictAllAgeAndGenderWithFaceAlignmentTask, import { PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
PredictSingleAgeAndGenderWithFaceAlignmentTask,
} from './PredictAgeAndGenderTask';
import {
PredictAllFaceExpressionsWithFaceAlignmentTask,
PredictSingleFaceExpressionsWithFaceAlignmentTask,
} from './PredictFaceExpressionsTask';
export class DetectFaceLandmarksTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> { export class DetectFaceLandmarksTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
constructor( constructor(

View File

@ -11,12 +11,7 @@ import { ComposableTask } from './ComposableTask';
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks'; import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults'; import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
import { nets } from './nets'; import { nets } from './nets';
import { import { PredictAllFaceExpressionsTask, PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
PredictAllFaceExpressionsTask,
PredictAllFaceExpressionsWithFaceAlignmentTask,
PredictSingleFaceExpressionsTask,
PredictSingleFaceExpressionsWithFaceAlignmentTask,
} from './PredictFaceExpressionsTask';
export class PredictAgeAndGenderTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> { export class PredictAgeAndGenderTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
constructor( constructor(

View File

@ -10,12 +10,7 @@ import { ComposableTask } from './ComposableTask';
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks'; import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults'; import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
import { nets } from './nets'; import { nets } from './nets';
import { import { PredictAllAgeAndGenderTask, PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
PredictAllAgeAndGenderTask,
PredictAllAgeAndGenderWithFaceAlignmentTask,
PredictSingleAgeAndGenderTask,
PredictSingleAgeAndGenderWithFaceAlignmentTask,
} from './PredictAgeAndGenderTask';
export class PredictFaceExpressionsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> { export class PredictFaceExpressionsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
constructor( constructor(

View File

@ -41,7 +41,6 @@ export function padToSquare(
paddingTensorAppend, paddingTensorAppend,
] ]
.filter((t) => !!t) .filter((t) => !!t)
// .map((t: tf.Tensor) => t.toFloat()) as tf.Tensor4D[]
.map((t: tf.Tensor) => tf.cast(t, 'float32')) as tf.Tensor4D[]; .map((t: tf.Tensor) => tf.cast(t, 'float32')) as tf.Tensor4D[];
return tf.concat(tensorsToStack, paddingAxis); return tf.concat(tensorsToStack, paddingAxis);
}); });

View File

@ -27,7 +27,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
return tf.tidy(() => { return tf.tidy(() => {
const batchTensor = tf.cast(input.toBatchTensor(512, false), 'float32'); const batchTensor = tf.cast(input.toBatchTensor(512, false), 'float32');
const x = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1)) as tf.Tensor4D; const x = tf.sub(tf.div(batchTensor, 127.5), 1) as tf.Tensor4D; // input is normalized -1..1
const features = mobileNetV1(x, params.mobilenetv1); const features = mobileNetV1(x, params.mobilenetv1);
const { boxPredictions, classPredictions } = predictionLayer(features.out, features.conv11, params.prediction_layer); const { boxPredictions, classPredictions } = predictionLayer(features.out, features.conv11, params.prediction_layer);
@ -40,10 +40,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
return this.forwardInput(await toNetInput(input)); return this.forwardInput(await toNetInput(input));
} }
public async locateFaces( public async locateFaces(input: TNetInput, options: ISsdMobilenetv1Options = {}): Promise<FaceDetection[]> {
input: TNetInput,
options: ISsdMobilenetv1Options = {},
): Promise<FaceDetection[]> {
const { maxResults, minConfidence } = new SsdMobilenetv1Options(options); const { maxResults, minConfidence } = new SsdMobilenetv1Options(options);
const netInput = await toNetInput(input); const netInput = await toNetInput(input);

View File

@ -1,11 +1,7 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { ExtractWeightsFunction, ParamMapping, ConvParams, extractWeightsFactory } from '../common/index';
ExtractWeightsFunction, ParamMapping, ConvParams, extractWeightsFactory, import { MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
} from '../common/index';
import {
MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams,
} from './types';
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) { function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
function extractDepthwiseConvParams(numChannels: number, mappedPrefix: string): MobileNetV1.DepthwiseConvParams { function extractDepthwiseConvParams(numChannels: number, mappedPrefix: string): MobileNetV1.DepthwiseConvParams {

View File

@ -1,12 +1,8 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { ConvParams, disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
ConvParams, disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping,
} from '../common/index';
import { isTensor3D } from '../utils/index'; import { isTensor3D } from '../utils/index';
import { import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams,
} from './types';
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) { function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings); const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);

View File

@ -5,11 +5,7 @@ import { MobileNetV1 } from './types';
const epsilon = 0.0010000000474974513; const epsilon = 0.0010000000474974513;
function depthwiseConvLayer( function depthwiseConvLayer(x: tf.Tensor4D, params: MobileNetV1.DepthwiseConvParams, strides: [number, number]) {
x: tf.Tensor4D,
params: MobileNetV1.DepthwiseConvParams,
strides: [number, number],
) {
return tf.tidy(() => { return tf.tidy(() => {
let out = tf.depthwiseConv2d(x, params.filters, strides, 'same'); let out = tf.depthwiseConv2d(x, params.filters, strides, 'same');
out = tf.batchNorm<tf.Rank.R4>( out = tf.batchNorm<tf.Rank.R4>(

View File

@ -12,15 +12,12 @@ function IOU(boxes: tf.Tensor2D, i: number, j: number) {
const xmaxJ = Math.max(boxesData[j][1], boxesData[j][3]); const xmaxJ = Math.max(boxesData[j][1], boxesData[j][3]);
const areaI = (ymaxI - yminI) * (xmaxI - xminI); const areaI = (ymaxI - yminI) * (xmaxI - xminI);
const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
if (areaI <= 0 || areaJ <= 0) { if (areaI <= 0 || areaJ <= 0) return 0.0;
return 0.0;
}
const intersectionYmin = Math.max(yminI, yminJ); const intersectionYmin = Math.max(yminI, yminJ);
const intersectionXmin = Math.max(xminI, xminJ); const intersectionXmin = Math.max(xminI, xminJ);
const intersectionYmax = Math.min(ymaxI, ymaxJ); const intersectionYmax = Math.min(ymaxI, ymaxJ);
const intersectionXmax = Math.min(xmaxI, xmaxJ); const intersectionXmax = Math.min(xmaxI, xmaxJ);
const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
* Math.max(intersectionXmax - intersectionXmin, 0.0);
return intersectionArea / (areaI + areaJ - intersectionArea); return intersectionArea / (areaI + areaJ - intersectionArea);
} }
@ -32,10 +29,7 @@ export function nonMaxSuppression(
scoreThreshold: number, scoreThreshold: number,
): number[] { ): number[] {
const numBoxes = boxes.shape[0]; const numBoxes = boxes.shape[0];
const outputSize = Math.min( const outputSize = Math.min(maxOutputSize, numBoxes);
maxOutputSize,
numBoxes,
);
const candidates = scores const candidates = scores
.map((score, boxIndex) => ({ score, boxIndex })) .map((score, boxIndex) => ({ score, boxIndex }))

View File

@ -13,17 +13,11 @@ function getCenterCoordinatesAndSizesLayer(x: tf.Tensor2D) {
tf.add(vec[0], tf.div(sizes[0], tf.scalar(2))), tf.add(vec[0], tf.div(sizes[0], tf.scalar(2))),
tf.add(vec[1], tf.div(sizes[1], tf.scalar(2))), tf.add(vec[1], tf.div(sizes[1], tf.scalar(2))),
]; ];
return { return { sizes, centers };
sizes,
centers,
};
} }
function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) { function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
const { const { sizes, centers } = getCenterCoordinatesAndSizesLayer(x0);
sizes,
centers,
} = getCenterCoordinatesAndSizesLayer(x0);
const vec = tf.unstack(tf.transpose(x1, [1, 0])); const vec = tf.unstack(tf.transpose(x1, [1, 0]));
const div0_out = tf.div(tf.mul(tf.exp(tf.div(vec[2], tf.scalar(5))), sizes[0]), tf.scalar(2)); const div0_out = tf.div(tf.mul(tf.exp(tf.div(vec[2], tf.scalar(5))), sizes[0]), tf.scalar(2));
@ -42,11 +36,7 @@ function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
); );
} }
export function outputLayer( export function outputLayer(boxPredictions: tf.Tensor4D, classPredictions: tf.Tensor4D, params: OutputLayerParams) {
boxPredictions: tf.Tensor4D,
classPredictions: tf.Tensor4D,
params: OutputLayerParams,
) {
return tf.tidy(() => { return tf.tidy(() => {
const batchSize = boxPredictions.shape[0]; const batchSize = boxPredictions.shape[0];
@ -54,25 +44,16 @@ export function outputLayer(
tf.reshape(tf.tile(params.extra_dim, [batchSize, 1, 1]), [-1, 4]) as tf.Tensor2D, tf.reshape(tf.tile(params.extra_dim, [batchSize, 1, 1]), [-1, 4]) as tf.Tensor2D,
tf.reshape(boxPredictions, [-1, 4]) as tf.Tensor2D, tf.reshape(boxPredictions, [-1, 4]) as tf.Tensor2D,
); );
boxes = tf.reshape( boxes = tf.reshape(boxes, [batchSize, (boxes.shape[0] / batchSize), 4]);
boxes,
[batchSize, (boxes.shape[0] / batchSize), 4],
);
const scoresAndClasses = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1])); const scoresAndClasses = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1]));
let scores = tf.slice(scoresAndClasses, [0, 0, 0], [-1, -1, 1]) as tf.Tensor; let scores = tf.slice(scoresAndClasses, [0, 0, 0], [-1, -1, 1]) as tf.Tensor;
scores = tf.reshape( scores = tf.reshape(scores, [batchSize, scores.shape[1] as number]);
scores,
[batchSize, scores.shape[1] as number],
);
const boxesByBatch = tf.unstack(boxes) as tf.Tensor2D[]; const boxesByBatch = tf.unstack(boxes) as tf.Tensor2D[];
const scoresByBatch = tf.unstack(scores) as tf.Tensor1D[]; const scoresByBatch = tf.unstack(scores) as tf.Tensor1D[];
return { return { boxes: boxesByBatch, scores: scoresByBatch };
boxes: boxesByBatch,
scores: scoresByBatch,
};
}); });
} }

View File

@ -1,11 +1,6 @@
import * as tf from '../../dist/tfjs.esm'; import * as tf from '../../dist/tfjs.esm';
import { import { disposeUnusedWeightTensors, extractWeightEntryFactory, loadSeparableConvParamsFactory, ParamMapping } from '../common/index';
disposeUnusedWeightTensors,
extractWeightEntryFactory,
loadSeparableConvParamsFactory,
ParamMapping,
} from '../common/index';
import { loadConvParamsFactory } from '../common/loadConvParamsFactory'; import { loadConvParamsFactory } from '../common/loadConvParamsFactory';
import { range } from '../utils/index'; import { range } from '../utils/index';
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types'; import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';