Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/dequantize_linear.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
'use strict';

import {mul, sub} from './binary.js';
import {blockwiseExpand} from './lib/broadcast.js';
import {validateQDQParams} from './lib/validate-input.js';

/**
* Elementwise operator to scale a low precision integer (typically uint8 with a zero-point bias)
Expand All @@ -12,5 +14,9 @@ import {mul, sub} from './binary.js';
* @return {Tensor}
*/
export function dequantizeLinear(input, scale, zeroPoint) {
  // Reject operands whose ranks/shapes are incompatible before doing any math.
  validateQDQParams(input, scale, zeroPoint);

  // scale and zeroPoint may be smaller than input (one value per block), so
  // both are expanded block-wise to the input's shape first.
  const broadcastedScale = blockwiseExpand(scale, input.shape);
  const broadcastedZeroPoint = blockwiseExpand(zeroPoint, input.shape);

  // output = (input - zeroPoint) * scale, computed elementwise.
  return mul(sub(input, broadcastedZeroPoint), broadcastedScale);
}
57 changes: 57 additions & 0 deletions src/lib/broadcast.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import {Tensor} from './tensor.js';
import {expand} from '../expand.js';
import {reshape} from '../reshape.js';

/**
* Broadcast a Tensor to a compatible shape NumPy-style.
Expand Down Expand Up @@ -70,3 +72,58 @@ export function getBroadcastShape(shapeA, shapeB) {
}
return outShape;
}

/**
 * Expand a tensor block-wise, axis by axis, until each axis has the length
 * given by outputShape. Along an axis that needs to grow, every element is
 * repeated (outputShape[axis] / currentLength) times.
 *
 * Backends may have a far cheaper way to do this, e.g. a nearest-neighbor
 * resample that upsamples all dimensions at once by integer multiples:
 *   output = resample(scale, {sizes: input.shape})
 *
 * @param {Tensor} input Tensor to expand; only its shape is inspected here.
 * @param {Array} outputShape Desired length of each axis. NOTE(review): this
 *     is assumed to have the same rank as input and to be an integer multiple
 *     of input.shape per axis - nothing in this function checks that.
 * @return {Tensor}
 */
export function blockwiseExpand(input, outputShape) {
  let result = input;

  for (let axis = 0; axis < input.shape.length; ++axis) {
    const currentShape = result.shape;
    const currentLength = currentShape[axis];
    const targetLength = outputShape[axis];

    // Nothing to do for an axis that already has the requested length.
    if (targetLength == currentLength) {
      continue;
    }

    // tile/expand repeat whole dimension slices rather than individual
    // elements along an axis. To repeat elements, view the tensor as
    // [before, axis, 1, after], grow the inserted unit dimension to the repeat
    // count, then fold it back into the axis.
    const [countBefore, axisLength, countAfter] =
        getFlattenedShapeAroundAxis(currentShape, axis);
    const repeatCount = targetLength / currentLength;

    const withUnitAxis =
        reshape(result, [countBefore, axisLength, 1, countAfter]);
    const repeated =
        expand(withUnitAxis, [countBefore, axisLength, repeatCount, countAfter]);

    const targetShape = Array.from(currentShape);
    targetShape[axis] = targetLength;
    result = reshape(repeated, targetShape);
  }

  return result;
}

// Collapse a shape into three numbers around one axis:
//   [product of sizes before axis, size at axis, product of sizes after axis]
// An out-of-range axis is clamped into [0, rank - 1] first. Examples:
// - inputShape = [2,3,4,5,6] with axis = 2 yields [6,4,30].
// - inputShape = [4] with axis = 0 yields [1,4,1].
function getFlattenedShapeAroundAxis(inputShape, axis) {
  const multiply = (product, size) => product * size;
  const lastAxis = inputShape.length - 1;
  const clampedAxis = Math.max(Math.min(axis, lastAxis), 0);

  const countBefore = inputShape.slice(0, clampedAxis).reduce(multiply, 1);
  const countAfter = inputShape.slice(clampedAxis + 1).reduce(multiply, 1);

  return [countBefore, inputShape[clampedAxis], countAfter];
}
29 changes: 29 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -861,3 +861,32 @@ export function validatePadParams(input, beginningPadding, endingPadding, mode)
}
}
}

/**
 * Validate the operands passed to quantizeLinear / dequantizeLinear.
 * Only the `rank` and `shape` properties of each operand are read.
 *
 * Checks, in order:
 *  1. scale has the same rank as input;
 *  2. zeroPoint has the same rank as input;
 *  3. scale and zeroPoint have identical shapes;
 *  4. every input dimension is an integer multiple of the matching scale
 *     dimension (so scale/zeroPoint can be expanded block-wise to the input).
 *
 * @param {Tensor} input
 * @param {Tensor} scale
 * @param {Tensor} zeroPoint
 * @throws {Error} If any of the checks above fails.
 */
export function validateQDQParams(input, scale, zeroPoint) {
  const inputRank = input.rank;
  const inputShape = input.shape;
  const scaleRank = scale.rank;
  const scaleShape = scale.shape;
  const zeroPointRank = zeroPoint.rank;
  const zeroPointShape = zeroPoint.shape;

  if (inputRank !== scaleRank) {
    throw new Error(
        `The scale's rank ${scaleRank} is not equal to the input's rank ${inputRank}.`);
  }

  if (inputRank !== zeroPointRank) {
    throw new Error(
        `The zeroPoint's rank ${zeroPointRank} is not equal to the input's rank ${inputRank}.`);
  }

  // Ranks are known to match at this point, so an index-wise comparison over
  // scaleShape covers every dimension of zeroPointShape as well.
  if (!scaleShape.every((size, index) => size === zeroPointShape[index])) {
    throw new Error(
        `The scale's shape [${scaleShape}] is not equal to the zeroPoint's shape [${zeroPointShape}].`);
  }

  // The `=== 0` belongs inside the predicate: each input dimension must divide
  // evenly by the corresponding scale dimension. (A zero-sized scale dimension
  // yields NaN here and is therefore rejected as well.)
  if (!inputShape.every((size, index) => size % scaleShape[index] === 0)) {
    throw new Error(
        `The input's shape [${inputShape}] is not a multiple of the scale's shape [${scaleShape}].`);
  }
}
27 changes: 24 additions & 3 deletions src/quantize_linear.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,26 @@
import {add, div} from './binary.js';
import {clamp} from './clamp.js';
import {unary} from './unary.js';
import {blockwiseExpand} from './lib/broadcast.js';
import {validateQDQParams} from './lib/validate-input.js';

/**
 * Round x to the nearest integer, resolving exact halves (x.5) to the nearest
 * even integer ("round half to even").
 *
 * Integers are returned unchanged; every other non-half value is rounded with
 * Math.round.
 * @param {Number} x
 * @return {Number} An integer number
 */
function roundToNearestEvens(x) {
  if (Number.isInteger(x)) {
    return x;
  }

  const lower = Math.floor(x);
  // x is not an integer here, so a distance of exactly 0.5 from floor(x)
  // identifies the tie case for both positive and negative values.
  if (x - lower === 0.5) {
    // Of the two neighbouring integers, pick the even one.
    return lower % 2 === 0 ? lower : Math.ceil(x);
  }

  return Math.round(x);
}

/**
Expand All @@ -19,9 +36,13 @@ function roundToNearestEvens(x) {
* @return {Tensor}
*/
export function quantizeLinear(input, scale, zeroPoint, dataType) {
const dividedOutput = div(input, scale);
validateQDQParams(input, scale, zeroPoint);

const broadcastedScale = blockwiseExpand(scale, input.shape);
const broadcastedZeroPoint = blockwiseExpand(zeroPoint, input.shape);
const dividedOutput = div(input, broadcastedScale);
const roundedOutput = unary(dividedOutput, (x) => roundToNearestEvens(x));
const addedOutput = add(roundedOutput, zeroPoint);
const addedOutput = add(roundedOutput, broadcastedZeroPoint);

let maxValue; let minValue;
switch (dataType) {
Expand Down
6 changes: 4 additions & 2 deletions test/dequantize_linear_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,19 @@ describe('test dequantizeLinear', function() {
],
},
{ // scale
shape: [3, 1],
shape: [1, 1, 3, 1],
value: [
1,
2,
4,
],
},
{ // zeroPoint
shape: [1],
shape: [1, 1, 3, 1],
value: [
0,
0,
0,
],
},
{ // expected
Expand Down
11 changes: 6 additions & 5 deletions test/quantize_linear_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ describe('test quantizeLinear', function() {
value: [
0, 2, 3, 255,
0, 1, 2, 255,
0, 0, 0, 200,
0, 0, 1, 200,
],
},
);
Expand All @@ -107,16 +107,17 @@ describe('test quantizeLinear', function() {
],
},
{ // scale
shape: [2, 1],
shape: [1, 1, 2, 1],
value: [
2,
2,
],
},
{ // zeroPoint
shape: [1],
shape: [1, 1, 2, 1],
value: [
10,
10,
],
},
'int8',
Expand All @@ -139,11 +140,11 @@ describe('test quantizeLinear', function() {
],
},
{ // scale
shape: [1],
shape: [1, 1, 1, 1],
value: [1],
},
{ // zeroPoint
shape: [1],
shape: [1, 1, 1, 1],
value: [100],
},
'int8',
Expand Down