From 72bd656e690e75d0354daf5a0c8937755f9aa522 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Fri, 9 May 2025 11:37:21 +0800 Subject: [PATCH] Remove symmetric mode and add parameters validation for pad op --- src/lib/validate-input.js | 30 ++++++++++++++ src/pad.js | 12 +++--- test/pad_test.js | 87 --------------------------------------- 3 files changed, 36 insertions(+), 93 deletions(-) diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index a560512..88ff160 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -831,3 +831,33 @@ export function validateTileParams(input, repetitions) { `Invalid repetitions ${repetitions} - it should be an Array of positive integers.`); } } + +export function validatePadParams(input, beginningPadding, endingPadding, mode) { + const inputRank = input.rank; + if (inputRank === 0) { + throw new Error(`The input's rank should be greater than 0.`); + } + if (beginningPadding.length !== inputRank) { + throw new Error(`Invalid beginningPadding, beginningPadding's size ${beginningPadding.length}` + + ` is not equal to input's rank ${inputRank}.`); + } + if (endingPadding.length !== inputRank) { + throw new Error(`Invalid endingPadding, endingPadding's size ${beginningPadding.length} is ` + + `not equal to input's rank ${inputRank}.`); + } + if (mode === 'reflection') { + const inputShape = input.shape; + for (let index = 0; index < inputRank; ++index) { + if (beginningPadding[index] >= inputShape[index]) { + throw new Error(`Invalid beginningPadding on reflection mode, beginningPadding[index] ` + + `${beginningPadding[index]} is greater than or equal to inputShape[index] ` + + `${inputShape[index]}.`); + } + if (endingPadding[index] >= inputShape[index]) { + throw new Error(`Invalid endingPadding on reflection mode, endingPadding[index] ` + + `${endingPadding[index]} is greater than or equal to inputShape[index] ` + + `${inputShape[index]}.`); + } + } + } +} diff --git a/src/pad.js b/src/pad.js index 24fca24..69ff9c2 100644 --- a/src/pad.js +++ b/src/pad.js @@ -1,6 +1,7 @@ 'use strict'; import {Tensor} from './lib/tensor.js'; +import {validatePadParams} from './lib/validate-input.js'; /** * Get mapped location from source tensor. @@ -24,16 +25,14 @@ function getMappedLocation(location, inputShape, beginningPadding, mode) { } } } else { - // reflection mode or symmetric mode - const offset = mode === 'symmetric' ? 1 : 0; + // reflection mode for (let i = 0; i < rank; i++) { if (mappedLocation[i] < beginningPadding[i]) { mappedLocation[i] = beginningPadding[i] + (beginningPadding[i] - mappedLocation[i]) - - beginningPadding[i] - offset; + beginningPadding[i]; } else if (mappedLocation[i] >= beginningPadding[i] + inputShape[i]) { mappedLocation[i] = beginningPadding[i] + inputShape[i] - 1 - - (mappedLocation[i] - (beginningPadding[i] + inputShape[i] -1)) - - beginningPadding[i] + offset; + (mappedLocation[i] - (beginningPadding[i] + inputShape[i] -1)) - beginningPadding[i]; } else { mappedLocation[i] -= beginningPadding[i]; } @@ -66,7 +65,7 @@ function updateOutputElement(index, source, destination, beginningPadding, mode, if (needPadding) { if (mode === 'constant') { result = value; - } else if (mode === 'edge' || mode === 'reflection' || mode === 'symmetric') { + } else if (mode === 'edge' || mode === 'reflection') { const targetLocation = getMappedLocation(location, sourceShape, beginningPadding, mode); result = source.getValueByLocation(targetLocation); } else { @@ -95,6 +94,7 @@ export function pad( mode='constant', value=0, } = {}) { + validatePadParams(input, beginningPadding, endingPadding, mode); const outputShape = input.shape.map((v, i) => v + beginningPadding[i] + endingPadding[i]); const output = new Tensor(outputShape); for (let i = 0; i < output.size; ++i) { diff --git a/test/pad_test.js b/test/pad_test.js index 102d368..b84a3a3 100644 --- a/test/pad_test.js +++ b/test/pad_test.js @@ -167,91 +167,4 @@ describe('test pad', function() { ], }); }); - - it('pad symmetric mode 2D', function() { - testPad( - { - shape: [2, 3], - values: [1, 2, 3, 4, 5, 6], - }, - [1, 2], - [1, 2], - { - mode: 'symmetric', - }, - { - shape: [4, 7], - values: [ - 2., 1., 1., 2., 3., 3., 2., - 2., 1., 1., 2., 3., 3., 2., - 5., 4., 4., 5., 6., 6., 5., - 5., 4., 4., 5., 6., 6., 5., - ], - }); - }); - - it('pad symmetric mode 4D', function() { - testPad( - { - shape: [2, 2, 3, 3], - values: [ - 0, 1, 2, - 3, 4, 5, - 6, 7, 8, - - 9, 10, 11, - 12, 13, 14, - 15, 16, 17, - - 18, 19, 20, - 21, 22, 23, - 24, 25, 26, - - 27, 28, 29, - 30, 31, 32, - 33, 34, 35, - ], - }, - [0, 0, 2, 2], - [0, 0, 2, 2], - { - mode: 'symmetric', - }, - { - shape: [2, 2, 7, 7], - values: [ - 4, 3, 3, 4, 5, 5, 4, - 1, 0, 0, 1, 2, 2, 1, - 1, 0, 0, 1, 2, 2, 1, - 4, 3, 3, 4, 5, 5, 4, - 7, 6, 6, 7, 8, 8, 7, - 7, 6, 6, 7, 8, 8, 7, - 4, 3, 3, 4, 5, 5, 4, - - 13, 12, 12, 13, 14, 14, 13, - 10, 9, 9, 10, 11, 11, 10, - 10, 9, 9, 10, 11, 11, 10, - 13, 12, 12, 13, 14, 14, 13, - 16, 15, 15, 16, 17, 17, 16, - 16, 15, 15, 16, 17, 17, 16, - 13, 12, 12, 13, 14, 14, 13, - - 22, 21, 21, 22, 23, 23, 22, - 19, 18, 18, 19, 20, 20, 19, - 19, 18, 18, 19, 20, 20, 19, - 22, 21, 21, 22, 23, 23, 22, - 25, 24, 24, 25, 26, 26, 25, - 25, 24, 24, 25, 26, 26, 25, - 22, 21, 21, 22, 23, 23, 22, - - 31, 30, 30, 31, 32, 32, 31, - 28, 27, 27, 28, 29, 29, 28, - 28, 27, 27, 28, 29, 29, 28, - 31, 30, 30, 31, 32, 32, 31, - 34, 33, 33, 34, 35, 35, 34, - 34, 33, 33, 34, 35, 35, 34, - 31, 30, 30, 31, 32, 32, 31, - ], - }); - }); });