refactoring
parent
d85c913347
commit
1b68ca1160
|
@ -1,2 +1,3 @@
|
|||
node_modules
|
||||
pnpm-lock.yaml
|
||||
test
|
||||
|
|
|
@ -20,8 +20,8 @@ function str(json) {
|
|||
function log(...txt) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(...txt);
|
||||
// @ts-ignore
|
||||
document.getElementById('log').innerHTML += `<br>${txt}`;
|
||||
const div = document.getElementById('log');
|
||||
if (div) div.innerHTML += `<br>${txt}`;
|
||||
}
|
||||
|
||||
// helper function to draw detected faces
|
||||
|
|
|
@ -19,8 +19,8 @@ function str(json) {
|
|||
function log(...txt) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(...txt);
|
||||
// @ts-ignore
|
||||
document.getElementById('log').innerHTML += `<br>${txt}`;
|
||||
const div = document.getElementById('log');
|
||||
if (div) div.innerHTML += `<br>${txt}`;
|
||||
}
|
||||
|
||||
// helper function to draw detected faces
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -22,7 +22,7 @@
|
|||
},
|
||||
"scripts": {
|
||||
"start": "node --trace-warnings demo/node-singleprocess.js",
|
||||
"dev": "npm install && node server/serve.js",
|
||||
"dev": "node --trace-warnings server/serve.js",
|
||||
"build": "rimraf dist/* types/* typedoc/* && node server/build.js",
|
||||
"lint": "eslint src/**/* demo/*.js server/*.js",
|
||||
"test": "eslint src/**/* demo/*.js server/*.js"
|
||||
|
@ -38,7 +38,6 @@
|
|||
"emotion-detection",
|
||||
"face-recognition"
|
||||
],
|
||||
"peerDependencies": {},
|
||||
"devDependencies": {
|
||||
"@tensorflow/tfjs": "^3.3.0",
|
||||
"@tensorflow/tfjs-backend-wasm": "^3.3.0",
|
||||
|
@ -50,7 +49,7 @@
|
|||
"@vladmandic/pilogger": "^0.2.15",
|
||||
"chokidar": "^3.5.1",
|
||||
"dayjs": "^1.10.4",
|
||||
"esbuild": "^0.9.3",
|
||||
"esbuild": "^0.9.5",
|
||||
"eslint": "^7.22.0",
|
||||
"eslint-config-airbnb-base": "^14.2.1",
|
||||
"eslint-plugin-import": "^2.22.1",
|
||||
|
|
|
@ -209,7 +209,7 @@ async function build(f, msg, dev = false) {
|
|||
for (const [targetGroupName, targetGroup] of Object.entries(targets)) {
|
||||
for (const [targetName, targetOptions] of Object.entries(targetGroup)) {
|
||||
// if triggered from watch mode, rebuild only browser bundle
|
||||
if ((require.main !== module) && (targetGroupName !== 'browserBundle')) continue;
|
||||
// if ((require.main !== module) && (targetGroupName !== 'browserBundle')) continue;
|
||||
// @ts-ignore
|
||||
const meta = await esbuild.build({ ...common, ...targetOptions });
|
||||
const stats = await getStats(meta);
|
||||
|
|
|
@ -5,9 +5,7 @@ import { seperateWeightMaps } from '../faceProcessor/util';
|
|||
import { TinyXception } from '../xception/TinyXception';
|
||||
import { extractParams } from './extractParams';
|
||||
import { extractParamsFromWeightMap } from './extractParamsFromWeightMap';
|
||||
import {
|
||||
AgeAndGenderPrediction, Gender, NetOutput, NetParams,
|
||||
} from './types';
|
||||
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
|
||||
import { NeuralNetwork } from '../NeuralNetwork';
|
||||
import { NetInput, TNetInput, toNetInput } from '../dom/index';
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
|
||||
} from '../common/index';
|
||||
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
|
||||
import { NetParams } from './types';
|
||||
|
||||
export function extractParamsFromWeightMap(
|
||||
|
|
|
@ -107,9 +107,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
|
|||
height += diff;
|
||||
}
|
||||
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
return new Box({ x, y, width, height });
|
||||
}
|
||||
|
||||
public rescale(s: IDimensions | number): Box<BoxType> {
|
||||
|
@ -136,9 +134,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
|
|||
}
|
||||
|
||||
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
|
||||
const {
|
||||
x, y, right, bottom,
|
||||
} = this;
|
||||
const { x, y, right, bottom } = this;
|
||||
const clippedX = Math.max(x, 0);
|
||||
const clippedY = Math.max(y, 0);
|
||||
|
||||
|
|
|
@ -3,9 +3,7 @@ import * as tf from '../../dist/tfjs.esm';
|
|||
import { Dimensions } from '../classes/Dimensions';
|
||||
import { env } from '../env/index';
|
||||
import { padToSquare } from '../ops/padToSquare';
|
||||
import {
|
||||
computeReshapedDimensions, isTensor3D, isTensor4D, range,
|
||||
} from '../utils/index';
|
||||
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
|
||||
import { createCanvasFromMedia } from './createCanvas';
|
||||
import { imageToSquare } from './imageToSquare';
|
||||
import { TResolvedNetInput } from './types';
|
||||
|
@ -23,10 +21,7 @@ export class NetInput {
|
|||
|
||||
private _inputSize: number
|
||||
|
||||
constructor(
|
||||
inputs: Array<TResolvedNetInput>,
|
||||
treatAsBatchInput: boolean = false,
|
||||
) {
|
||||
constructor(inputs: Array<TResolvedNetInput>, treatAsBatchInput: boolean = false) {
|
||||
if (!Array.isArray(inputs)) {
|
||||
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
|
||||
}
|
||||
|
@ -131,13 +126,11 @@ export class NetInput {
|
|||
const input = this.getInput(batchIdx);
|
||||
|
||||
if (input instanceof tf.Tensor) {
|
||||
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
|
||||
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>();
|
||||
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
|
||||
let imgTensor = isTensor4D(input) ? input : tf.expandDims(input);
|
||||
imgTensor = padToSquare(imgTensor, isCenterInputs);
|
||||
|
||||
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
|
||||
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
|
||||
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize], false, false);
|
||||
}
|
||||
|
||||
return imgTensor.as3D(inputSize, inputSize, 3);
|
||||
|
@ -150,9 +143,7 @@ export class NetInput {
|
|||
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
|
||||
});
|
||||
|
||||
// const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
|
||||
const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
|
||||
// const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3);
|
||||
|
||||
return batchTensor;
|
||||
});
|
||||
|
|
|
@ -14,10 +14,7 @@ import { TNetInput } from './types';
|
|||
* @param detections The face detection results or face bounding boxes for that image.
|
||||
* @returns The Canvases of the corresponding image region for each detected face.
|
||||
*/
|
||||
export async function extractFaces(
|
||||
input: TNetInput,
|
||||
detections: Array<FaceDetection | Rect>,
|
||||
): Promise<HTMLCanvasElement[]> {
|
||||
export async function extractFaces(input: TNetInput, detections: Array<FaceDetection | Rect>): Promise<HTMLCanvasElement[]> {
|
||||
const { Canvas } = env.getEnv();
|
||||
|
||||
let canvas = input as HTMLCanvasElement;
|
||||
|
@ -36,16 +33,13 @@ export async function extractFaces(
|
|||
}
|
||||
|
||||
const ctx = getContext2dOrThrow(canvas);
|
||||
const boxes = detections.map(
|
||||
(det) => (det instanceof FaceDetection
|
||||
const boxes = detections
|
||||
.map((det) => (det instanceof FaceDetection
|
||||
? det.forSize(canvas.width, canvas.height).box.floor()
|
||||
: det),
|
||||
)
|
||||
: det))
|
||||
.map((box) => box.clipAtImageBorders(canvas.width, canvas.height));
|
||||
|
||||
return boxes.map(({
|
||||
x, y, width, height,
|
||||
}) => {
|
||||
return boxes.map(({ x, y, width, height }) => {
|
||||
const faceImg = createCanvas({ width, height });
|
||||
if (width > 0 && height > 0) getContext2dOrThrow(faceImg).putImageData(ctx.getImageData(x, y, width, height), 0, 0);
|
||||
return faceImg;
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
/* eslint-disable max-classes-per-file */
|
||||
import { Box, IBoundingBox, IRect } from '../classes/index';
|
||||
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
|
||||
import {
|
||||
AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions,
|
||||
} from './DrawTextField';
|
||||
import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField';
|
||||
|
||||
export interface IDrawBoxOptions {
|
||||
boxColor?: string
|
||||
|
|
|
@ -2,7 +2,5 @@ export function isNodejs(): boolean {
|
|||
return typeof global === 'object'
|
||||
&& typeof require === 'function'
|
||||
&& typeof module !== 'undefined'
|
||||
// issues with gatsby.js: module.exports is undefined
|
||||
// && !!module.exports
|
||||
&& typeof process !== 'undefined' && !!process.version;
|
||||
}
|
||||
|
|
|
@ -1,9 +1,4 @@
|
|||
import {
|
||||
extractConvParamsFactory,
|
||||
extractSeparableConvParamsFactory,
|
||||
ExtractWeightsFunction,
|
||||
ParamMapping,
|
||||
} from '../common/index';
|
||||
import { extractConvParamsFactory, extractSeparableConvParamsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
|
||||
import { DenseBlock3Params, DenseBlock4Params } from './types';
|
||||
|
||||
export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
|
||||
|
|
|
@ -2,11 +2,7 @@ import * as tf from '../../dist/tfjs.esm';
|
|||
|
||||
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
|
||||
import { NetInput } from '../dom/index';
|
||||
import {
|
||||
FaceFeatureExtractorParams,
|
||||
IFaceFeatureExtractor,
|
||||
TinyFaceFeatureExtractorParams,
|
||||
} from '../faceFeatureExtractor/types';
|
||||
import { FaceFeatureExtractorParams, IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
|
||||
import { NeuralNetwork } from '../NeuralNetwork';
|
||||
import { extractParams } from './extractParams';
|
||||
import { extractParamsFromWeightMap } from './extractParamsFromWeightMap';
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
|
||||
} from '../common/index';
|
||||
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
|
||||
import { NetParams } from './types';
|
||||
|
||||
export function extractParamsFromWeightMap(
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping,
|
||||
} from '../common/index';
|
||||
import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
|
||||
import { isFloat } from '../utils/index';
|
||||
import {
|
||||
ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams,
|
||||
} from './types';
|
||||
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
|
||||
|
||||
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
|
||||
function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {
|
||||
|
|
|
@ -6,14 +6,8 @@ import { WithFaceLandmarks } from '../factories/WithFaceLandmarks';
|
|||
import { ComposableTask } from './ComposableTask';
|
||||
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
|
||||
import { nets } from './nets';
|
||||
import {
|
||||
PredictAllAgeAndGenderWithFaceAlignmentTask,
|
||||
PredictSingleAgeAndGenderWithFaceAlignmentTask,
|
||||
} from './PredictAgeAndGenderTask';
|
||||
import {
|
||||
PredictAllFaceExpressionsWithFaceAlignmentTask,
|
||||
PredictSingleFaceExpressionsWithFaceAlignmentTask,
|
||||
} from './PredictFaceExpressionsTask';
|
||||
import { PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
|
||||
import { PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
|
||||
|
||||
export class ComputeFaceDescriptorsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
|
||||
constructor(
|
||||
|
|
|
@ -10,14 +10,8 @@ import { extendWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFac
|
|||
import { ComposableTask } from './ComposableTask';
|
||||
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
|
||||
import { nets } from './nets';
|
||||
import {
|
||||
PredictAllAgeAndGenderWithFaceAlignmentTask,
|
||||
PredictSingleAgeAndGenderWithFaceAlignmentTask,
|
||||
} from './PredictAgeAndGenderTask';
|
||||
import {
|
||||
PredictAllFaceExpressionsWithFaceAlignmentTask,
|
||||
PredictSingleFaceExpressionsWithFaceAlignmentTask,
|
||||
} from './PredictFaceExpressionsTask';
|
||||
import { PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
|
||||
import { PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
|
||||
|
||||
export class DetectFaceLandmarksTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
|
||||
constructor(
|
||||
|
|
|
@ -11,12 +11,7 @@ import { ComposableTask } from './ComposableTask';
|
|||
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
|
||||
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
|
||||
import { nets } from './nets';
|
||||
import {
|
||||
PredictAllFaceExpressionsTask,
|
||||
PredictAllFaceExpressionsWithFaceAlignmentTask,
|
||||
PredictSingleFaceExpressionsTask,
|
||||
PredictSingleFaceExpressionsWithFaceAlignmentTask,
|
||||
} from './PredictFaceExpressionsTask';
|
||||
import { PredictAllFaceExpressionsTask, PredictAllFaceExpressionsWithFaceAlignmentTask, PredictSingleFaceExpressionsTask, PredictSingleFaceExpressionsWithFaceAlignmentTask } from './PredictFaceExpressionsTask';
|
||||
|
||||
export class PredictAgeAndGenderTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
|
||||
constructor(
|
||||
|
|
|
@ -10,12 +10,7 @@ import { ComposableTask } from './ComposableTask';
|
|||
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
|
||||
import { extractAllFacesAndComputeResults, extractSingleFaceAndComputeResult } from './extractFacesAndComputeResults';
|
||||
import { nets } from './nets';
|
||||
import {
|
||||
PredictAllAgeAndGenderTask,
|
||||
PredictAllAgeAndGenderWithFaceAlignmentTask,
|
||||
PredictSingleAgeAndGenderTask,
|
||||
PredictSingleAgeAndGenderWithFaceAlignmentTask,
|
||||
} from './PredictAgeAndGenderTask';
|
||||
import { PredictAllAgeAndGenderTask, PredictAllAgeAndGenderWithFaceAlignmentTask, PredictSingleAgeAndGenderTask, PredictSingleAgeAndGenderWithFaceAlignmentTask } from './PredictAgeAndGenderTask';
|
||||
|
||||
export class PredictFaceExpressionsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
|
||||
constructor(
|
||||
|
|
|
@ -41,7 +41,6 @@ export function padToSquare(
|
|||
paddingTensorAppend,
|
||||
]
|
||||
.filter((t) => !!t)
|
||||
// .map((t: tf.Tensor) => t.toFloat()) as tf.Tensor4D[]
|
||||
.map((t: tf.Tensor) => tf.cast(t, 'float32')) as tf.Tensor4D[];
|
||||
return tf.concat(tensorsToStack, paddingAxis);
|
||||
});
|
||||
|
|
|
@ -27,7 +27,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
|
|||
|
||||
return tf.tidy(() => {
|
||||
const batchTensor = tf.cast(input.toBatchTensor(512, false), 'float32');
|
||||
const x = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1)) as tf.Tensor4D;
|
||||
const x = tf.sub(tf.div(batchTensor, 127.5), 1) as tf.Tensor4D; // input is normalized -1..1
|
||||
const features = mobileNetV1(x, params.mobilenetv1);
|
||||
|
||||
const { boxPredictions, classPredictions } = predictionLayer(features.out, features.conv11, params.prediction_layer);
|
||||
|
@ -40,10 +40,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
|
|||
return this.forwardInput(await toNetInput(input));
|
||||
}
|
||||
|
||||
public async locateFaces(
|
||||
input: TNetInput,
|
||||
options: ISsdMobilenetv1Options = {},
|
||||
): Promise<FaceDetection[]> {
|
||||
public async locateFaces(input: TNetInput, options: ISsdMobilenetv1Options = {}): Promise<FaceDetection[]> {
|
||||
const { maxResults, minConfidence } = new SsdMobilenetv1Options(options);
|
||||
|
||||
const netInput = await toNetInput(input);
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
ExtractWeightsFunction, ParamMapping, ConvParams, extractWeightsFactory,
|
||||
} from '../common/index';
|
||||
import {
|
||||
MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams,
|
||||
} from './types';
|
||||
import { ExtractWeightsFunction, ParamMapping, ConvParams, extractWeightsFactory } from '../common/index';
|
||||
import { MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
|
||||
|
||||
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
|
||||
function extractDepthwiseConvParams(numChannels: number, mappedPrefix: string): MobileNetV1.DepthwiseConvParams {
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
ConvParams, disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping,
|
||||
} from '../common/index';
|
||||
import { ConvParams, disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
|
||||
import { isTensor3D } from '../utils/index';
|
||||
import {
|
||||
BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams,
|
||||
} from './types';
|
||||
import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
|
||||
|
||||
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
|
||||
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
|
||||
|
|
|
@ -5,11 +5,7 @@ import { MobileNetV1 } from './types';
|
|||
|
||||
const epsilon = 0.0010000000474974513;
|
||||
|
||||
function depthwiseConvLayer(
|
||||
x: tf.Tensor4D,
|
||||
params: MobileNetV1.DepthwiseConvParams,
|
||||
strides: [number, number],
|
||||
) {
|
||||
function depthwiseConvLayer(x: tf.Tensor4D, params: MobileNetV1.DepthwiseConvParams, strides: [number, number]) {
|
||||
return tf.tidy(() => {
|
||||
let out = tf.depthwiseConv2d(x, params.filters, strides, 'same');
|
||||
out = tf.batchNorm<tf.Rank.R4>(
|
||||
|
|
|
@ -12,15 +12,12 @@ function IOU(boxes: tf.Tensor2D, i: number, j: number) {
|
|||
const xmaxJ = Math.max(boxesData[j][1], boxesData[j][3]);
|
||||
const areaI = (ymaxI - yminI) * (xmaxI - xminI);
|
||||
const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
|
||||
if (areaI <= 0 || areaJ <= 0) {
|
||||
return 0.0;
|
||||
}
|
||||
if (areaI <= 0 || areaJ <= 0) return 0.0;
|
||||
const intersectionYmin = Math.max(yminI, yminJ);
|
||||
const intersectionXmin = Math.max(xminI, xminJ);
|
||||
const intersectionYmax = Math.min(ymaxI, ymaxJ);
|
||||
const intersectionXmax = Math.min(xmaxI, xmaxJ);
|
||||
const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0)
|
||||
* Math.max(intersectionXmax - intersectionXmin, 0.0);
|
||||
const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
|
||||
return intersectionArea / (areaI + areaJ - intersectionArea);
|
||||
}
|
||||
|
||||
|
@ -32,10 +29,7 @@ export function nonMaxSuppression(
|
|||
scoreThreshold: number,
|
||||
): number[] {
|
||||
const numBoxes = boxes.shape[0];
|
||||
const outputSize = Math.min(
|
||||
maxOutputSize,
|
||||
numBoxes,
|
||||
);
|
||||
const outputSize = Math.min(maxOutputSize, numBoxes);
|
||||
|
||||
const candidates = scores
|
||||
.map((score, boxIndex) => ({ score, boxIndex }))
|
||||
|
|
|
@ -13,17 +13,11 @@ function getCenterCoordinatesAndSizesLayer(x: tf.Tensor2D) {
|
|||
tf.add(vec[0], tf.div(sizes[0], tf.scalar(2))),
|
||||
tf.add(vec[1], tf.div(sizes[1], tf.scalar(2))),
|
||||
];
|
||||
return {
|
||||
sizes,
|
||||
centers,
|
||||
};
|
||||
return { sizes, centers };
|
||||
}
|
||||
|
||||
function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
|
||||
const {
|
||||
sizes,
|
||||
centers,
|
||||
} = getCenterCoordinatesAndSizesLayer(x0);
|
||||
const { sizes, centers } = getCenterCoordinatesAndSizesLayer(x0);
|
||||
|
||||
const vec = tf.unstack(tf.transpose(x1, [1, 0]));
|
||||
const div0_out = tf.div(tf.mul(tf.exp(tf.div(vec[2], tf.scalar(5))), sizes[0]), tf.scalar(2));
|
||||
|
@ -42,11 +36,7 @@ function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
|
|||
);
|
||||
}
|
||||
|
||||
export function outputLayer(
|
||||
boxPredictions: tf.Tensor4D,
|
||||
classPredictions: tf.Tensor4D,
|
||||
params: OutputLayerParams,
|
||||
) {
|
||||
export function outputLayer(boxPredictions: tf.Tensor4D, classPredictions: tf.Tensor4D, params: OutputLayerParams) {
|
||||
return tf.tidy(() => {
|
||||
const batchSize = boxPredictions.shape[0];
|
||||
|
||||
|
@ -54,25 +44,16 @@ export function outputLayer(
|
|||
tf.reshape(tf.tile(params.extra_dim, [batchSize, 1, 1]), [-1, 4]) as tf.Tensor2D,
|
||||
tf.reshape(boxPredictions, [-1, 4]) as tf.Tensor2D,
|
||||
);
|
||||
boxes = tf.reshape(
|
||||
boxes,
|
||||
[batchSize, (boxes.shape[0] / batchSize), 4],
|
||||
);
|
||||
boxes = tf.reshape(boxes, [batchSize, (boxes.shape[0] / batchSize), 4]);
|
||||
|
||||
const scoresAndClasses = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1]));
|
||||
let scores = tf.slice(scoresAndClasses, [0, 0, 0], [-1, -1, 1]) as tf.Tensor;
|
||||
|
||||
scores = tf.reshape(
|
||||
scores,
|
||||
[batchSize, scores.shape[1] as number],
|
||||
);
|
||||
scores = tf.reshape(scores, [batchSize, scores.shape[1] as number]);
|
||||
|
||||
const boxesByBatch = tf.unstack(boxes) as tf.Tensor2D[];
|
||||
const scoresByBatch = tf.unstack(scores) as tf.Tensor1D[];
|
||||
|
||||
return {
|
||||
boxes: boxesByBatch,
|
||||
scores: scoresByBatch,
|
||||
};
|
||||
return { boxes: boxesByBatch, scores: scoresByBatch };
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,11 +1,6 @@
|
|||
import * as tf from '../../dist/tfjs.esm';
|
||||
|
||||
import {
|
||||
disposeUnusedWeightTensors,
|
||||
extractWeightEntryFactory,
|
||||
loadSeparableConvParamsFactory,
|
||||
ParamMapping,
|
||||
} from '../common/index';
|
||||
import { disposeUnusedWeightTensors, extractWeightEntryFactory, loadSeparableConvParamsFactory, ParamMapping } from '../common/index';
|
||||
import { loadConvParamsFactory } from '../common/loadConvParamsFactory';
|
||||
import { range } from '../utils/index';
|
||||
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';
|
||||
|
|
Loading…
Reference in New Issue