From 814c241b6a0509a966b458a3017f193e91451d1a Mon Sep 17 00:00:00 2001 From: BruceDai Date: Mon, 6 Jan 2025 15:00:10 +0800 Subject: [PATCH 1/7] update reshape without null --- src/lstm.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index 99474d7..bd3f660 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -86,8 +86,8 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], layout: layout, activations: activations}); - const output = reshape(results[0], [1, null, hiddenSize]); - const cell = reshape(results[1], [1, null, hiddenSize]); + const output = reshape(results[0], [1, batchSize, hiddenSize]); + const cell = reshape(results[1], [1, batchSize, hiddenSize]); nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); @@ -97,7 +97,7 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, cellState = nextCell; if (returnSequence) { - nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]); + nextHidden = reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]); sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); } } From f7e4fff474e1eafe67e7836ef2d66d361b572b12 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 7 Jan 2025 10:46:39 +0800 Subject: [PATCH 2/7] fix output sequence issue with backward and both directions --- src/lstm.js | 22 +++++++++++++- test/lstm_test.js | 74 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index bd3f660..6068651 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -3,8 +3,10 @@ import {concat} from './concat.js'; import {lstmCell} from './lstm_cell.js'; import {reshape, squeeze} from './reshape.js'; +import {reverse} from './reverse.js'; import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; +import {split} from './split.js'; import {slice} from './slice.js'; import {tanh} from './tanh.js'; import {validateLstmParams} from './lib/validate-input.js'; @@ -85,10 +87,11 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir], recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], layout: layout, activations: activations}); - + // Expand [batchSize, hiddenSize] to [numDirections, batchSize, hiddenSize] const output = reshape(results[0], [1, batchSize, hiddenSize]); const cell = reshape(results[1], [1, batchSize, hiddenSize]); + // Concat along 0 axis (for numDirections dimension) nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); } @@ -97,10 +100,27 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, cellState = nextCell; if (returnSequence) { + // Expand [numDirections, batchSize, hiddenSize] to + // [steps, numDirections, batchSize, hiddenSize] nextHidden = reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]); + // Concat output sequence along 0 axis (for steps dimension) sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); } } + if (direction === 'backward') { + // Reverse output sequence alog [0] axes (for steps dimension) + sequence = reverse(sequence, {axes: [0]}); + } else if (direction === 'both') { + // Split output sequence into forward-sequence and backward-sequence two sequences along 1 axis + // (for numDirections dimension) + const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); + // Reverse backward-sequence alog [0] axes (for only steps dimension) + const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); + sequence = concat([sequenceForward, reversedSequenceBackward], 1); + } else { + // No need update sequence for 'forward' direction + } + return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]); } diff --git a/test/lstm_test.js b/test/lstm_test.js index c03d885..b887480 100644 --- a/test/lstm_test.js +++ b/test/lstm_test.js @@ -45,7 +45,6 @@ describe('test lstm', function() { input, weight, recurrentWeight, steps, hiddenSize, {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, returnSequence, activations}); - console.log('outputs: ', outputs); utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); @@ -65,7 +64,7 @@ describe('test lstm', function() { } }); - it('lstm steps=2 direction="backward" returnSequence=true' + + it('lstm steps=2 direction="backward" returnSequence=true ' + 'activations=[relu, relu, relu]', function() { const steps = 2; const numDirections = 1; @@ -106,7 +105,6 @@ describe('test lstm', function() { input, weight, recurrentWeight, steps, hiddenSize, {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, direction, returnSequence, activations}); - console.log('outputs: ', outputs); utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); @@ -114,14 +112,76 @@ describe('test lstm', function() { [10.469, 58.02899999999999, 74.529, 518.9490000000001], [5.51, 20.009999999999998, 19.11, 75.21000000000001], [ - 1, - 8, - 1, - 8, 10.469, 58.02899999999999, 74.529, 518.9490000000001, + 1, + 8, + 1, + 8, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); + + it('lstm steps=2 direction="both" returnSequence=true', function() { + const steps = 2; + const numDirections = 2; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], + new Float32Array([1, 2, 2, 1, 3, 4, 1, 2])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(2 * 4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const returnSequence = true; + const direction = 'both'; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, direction, returnSequence}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [ + 0.5764073262004139, 0.8236227651782412, + 0.6612355785279247, 0.8442635760318142, + 0.5764073262004139, 0.8236227651782412, + 0.8635294727880538, 0.9491350760903781, + ], + [ + 1.0171455721466105, 1.6205496282195793, + 1.338846378789257, 1.7642604746965693, + 1.0171455721466105, 1.6205496282195793, + 1.485626937219704, 1.8449554199024933, + ], + [ + 0.36960635293570576, 0.6082834181835157, + 0.7037753329989016, 0.7586680430344475, + 0.5764073262004139, 0.8236227651782412, + 0.8635294727880538, 0.9491350760903781, + 0.5764073262004139, 0.8236227651782412, + 0.6612355785279247, 0.8442635760318142, + 0.36960635293570576, 0.6082834181835157, + 0.36960635293570576, 0.6082834181835157, ], ]; for (let i = 0; i < expected.length; ++i) { From 613f056709e052f01bb6ade25fba774652ef2b43 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 7 Jan 2025 16:48:18 +0800 Subject: [PATCH 3/7] address comments --- src/lstm.js | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index 6068651..0a557db 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -91,7 +91,7 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, const output = reshape(results[0], [1, batchSize, hiddenSize]); const cell = reshape(results[1], [1, batchSize, hiddenSize]); - // Concat along 0 axis (for numDirections dimension) + // Concat along axis 0 (numDirections dimension) nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); } @@ -103,19 +103,22 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, // Expand [numDirections, batchSize, hiddenSize] to // [steps, numDirections, batchSize, hiddenSize] nextHidden = reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]); - // Concat output sequence along 0 axis (for steps dimension) + // Concat output sequence along axis 0 (steps dimension) sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); } } if (direction === 'backward') { - // Reverse output sequence alog [0] axes (for steps dimension) + // Refer to https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm, Spec says the + // sequence should contain every output from each time step in the temporal sequence, while + // the loop for steps concatenates sequence in a reversed order when direction is backward, + // so here need reverse output sequence along axis 0 (steps dimension). sequence = reverse(sequence, {axes: [0]}); } else if (direction === 'both') { - // Split output sequence into forward-sequence and backward-sequence two sequences along 1 axis - // (for numDirections dimension) + // Split output sequence into forward-sequence and backward-sequence two sequences along axis 1 + // (numDirections dimension) const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); - // Reverse backward-sequence alog [0] axes (for only steps dimension) + // Reverse backward-sequence along axis 0 (steps dimension) const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); sequence = concat([sequenceForward, reversedSequenceBackward], 1); } else { From 58cb672e6adc721e101b89ce650c1b8676b16d0c Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 8 Jan 2025 11:49:50 +0800 Subject: [PATCH 4/7] fix sequence issue with returnSequence option --- src/lstm.js | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index 0a557db..15f2618 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -108,22 +108,25 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, } } - if (direction === 'backward') { - // Refer to https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm, Spec says the - // sequence should contain every output from each time step in the temporal sequence, while - // the loop for steps concatenates sequence in a reversed order when direction is backward, - // so here need reverse output sequence along axis 0 (steps dimension). - sequence = reverse(sequence, {axes: [0]}); - } else if (direction === 'both') { - // Split output sequence into forward-sequence and backward-sequence two sequences along axis 1 - // (numDirections dimension) - const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); - // Reverse backward-sequence along axis 0 (steps dimension) - const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); - sequence = concat([sequenceForward, reversedSequenceBackward], 1); + if (returnSequence) { + if (direction === 'backward') { + // Refer to https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm, Spec says the + // sequence should contain every output from each time step in the temporal sequence, while + // the loop for steps concatenates sequence in a reversed order when direction is backward, + // so here need reverse output sequence along axis 0 (steps dimension). + sequence = reverse(sequence, {axes: [0]}); + } else if (direction === 'both') { + // Split output sequence into forward-sequence and backward-sequence two sequences along axis 1 + // (numDirections dimension) + const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); + // Reverse backward-sequence along axis 0 (steps dimension) + const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); + sequence = concat([sequenceForward, reversedSequenceBackward], 1); + } else { + // No need update sequence for 'forward' direction + } + return [hiddenState, cellState, sequence]; } else { - // No need update sequence for 'forward' direction + return [hiddenState, cellState]; } - - return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]); } From ffceac78d13d7c8072b31ce4ed49ae01e2ca5cd9 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 8 Jan 2025 15:07:03 +0800 Subject: [PATCH 5/7] fix lint error --- src/lstm.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index 15f2618..f1fe5f2 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -116,8 +116,8 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, // so here need reverse output sequence along axis 0 (steps dimension). sequence = reverse(sequence, {axes: [0]}); } else if (direction === 'both') { - // Split output sequence into forward-sequence and backward-sequence two sequences along axis 1 - // (numDirections dimension) + // Split output sequence into forward-sequence and backward-sequence two sequences along axis + // 1 (numDirections dimension) const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); // Reverse backward-sequence along axis 0 (steps dimension) const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); From 0ae6903125b22bed9091a9b56132f418b4c052c3 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 8 Jan 2025 18:41:29 +0800 Subject: [PATCH 6/7] refine implementation of lstm --- src/lstm.js | 93 +++++++++++++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index f1fe5f2..18c5841 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -3,10 +3,8 @@ import {concat} from './concat.js'; import {lstmCell} from './lstm_cell.js'; import {reshape, squeeze} from './reshape.js'; -import {reverse} from './reverse.js'; import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; -import {split} from './split.js'; import {slice} from './slice.js'; import {tanh} from './tanh.js'; import {validateLstmParams} from './lib/validate-input.js'; @@ -49,12 +47,17 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, initialCellState, new Array(sizeOfShape(initialCellState)).fill(0)); } - let sequence; const currentWeight = []; const currentRecurrentWeight = []; const currentBias = []; const currentRecurrentBias = []; const currentPeepholeWeight = []; + let forwardSequence = null; + let backwardSequence = null; + let currentHidden; + let currentCell; + let outputHidden; + let outputCell; for (let dir = 0; dir < numDirections; ++dir) { currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize]))); @@ -65,68 +68,60 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, (squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null); currentPeepholeWeight.push(peepholeWeight ? (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null); - } - - for (let step = 0; step < steps; ++step) { - const currentHidden = []; - const currentCell = []; - let nextHidden = null; - let nextCell = null; - for (let dir = 0; dir < numDirections; ++dir) { - currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]))); - currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]))); - } + currentHidden = squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])); + currentCell = squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])); - for (let dir = 0; dir < numDirections; ++dir) { - const slice0 = (dir == 1 || direction == 'backward' ? steps - step - 1 : step); + for (let step = 0; step < steps; ++step) { + const slice0 = dir === 1 || direction === 'backward' ? steps - step - 1 : step; const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize])); - const results = lstmCell( + [currentHidden, currentCell] = lstmCell( currentInput, currentWeight[dir], currentRecurrentWeight[dir], - currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir], + currentHidden, currentCell, hiddenSize, {bias: currentBias[dir], recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], layout: layout, activations: activations}); - // Expand [batchSize, hiddenSize] to [numDirections, batchSize, hiddenSize] - const output = reshape(results[0], [1, batchSize, hiddenSize]); - const cell = reshape(results[1], [1, batchSize, hiddenSize]); - // Concat along axis 0 (numDirections dimension) - nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); - nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); + if (returnSequence) { + // Expand hidden of 2D([batchSize, hiddenSize]) to + // 4D([steps, numDirections, batchSize, hiddenSize]) + const expandedHiddenAs4D = reshape(currentHidden, [1, 1, batchSize, hiddenSize]); + if (direction === 'forward' || (dir === 0 && direction === 'both')) { + forwardSequence = forwardSequence ? + concat([forwardSequence, expandedHiddenAs4D], 0) : + expandedHiddenAs4D; + } else if (direction === 'backward' || (dir === 1 && direction === 'both')) { + backwardSequence = backwardSequence ? + concat([expandedHiddenAs4D, backwardSequence], 0) : + expandedHiddenAs4D; + } + } } - hiddenState = nextHidden; - cellState = nextCell; + // Expand hidden of 2D([batchSize, hiddenSize]) to 3D([numDirections, batchSize, hiddenSize]) + const expandHiddenAs3D = reshape(currentHidden, [1, batchSize, hiddenSize]); + // Concat along axis 0 (numDirections dimension) + outputHidden = outputHidden ? concat([outputHidden, expandHiddenAs3D], 0) : expandHiddenAs3D; - if (returnSequence) { - // Expand [numDirections, batchSize, hiddenSize] to - // [steps, numDirections, batchSize, hiddenSize] - nextHidden = reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]); - // Concat output sequence along axis 0 (steps dimension) - sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); - } + // Expand cell of 2D([batchSize, hiddenSize]) to 3D([numDirections, batchSize, hiddenSize]) + const expandCellAs3D = reshape(currentCell, [1, batchSize, hiddenSize]); + // Concat along axis 0 (numDirections dimension) + outputCell = outputCell ? concat([outputCell, expandCellAs3D], 0) : expandCellAs3D; } if (returnSequence) { - if (direction === 'backward') { - // Refer to https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm, Spec says the - // sequence should contain every output from each time step in the temporal sequence, while - // the loop for steps concatenates sequence in a reversed order when direction is backward, - // so here need reverse output sequence along axis 0 (steps dimension). - sequence = reverse(sequence, {axes: [0]}); + // outputSequence: [steps, numDirections, batchSize, hiddenSize] + let outputSequence; + if (direction === 'forward') { + outputSequence = forwardSequence; + } else if (direction === 'backward') { + outputSequence = backwardSequence; } else if (direction === 'both') { - // Split output sequence into forward-sequence and backward-sequence two sequences along axis - // 1 (numDirections dimension) - const [sequenceForward, sequenceBackward] = split(sequence, 2, {axis: 1}); - // Reverse backward-sequence along axis 0 (steps dimension) - const reversedSequenceBackward = reverse(sequenceBackward, {axes: [0]}); - sequence = concat([sequenceForward, reversedSequenceBackward], 1); - } else { - // No need update sequence for 'forward' direction + // Concat along axis 1 (numDirections dimension) + outputSequence = concat([forwardSequence, backwardSequence], 1); } - return [hiddenState, cellState, sequence]; + return [outputHidden, outputCell, outputSequence]; } else { - return [hiddenState, cellState]; + return [outputHidden, outputCell]; } } From d16218b9c2906313f3888eef2d7de871dd1bb5dd Mon Sep 17 00:00:00 2001 From: BruceDai Date: Thu, 9 Jan 2025 08:08:39 +0800 Subject: [PATCH 7/7] address comments --- src/lstm.js | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lstm.js b/src/lstm.js index 18c5841..a1d0fd4 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -54,10 +54,8 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, const currentPeepholeWeight = []; let forwardSequence = null; let backwardSequence = null; - let currentHidden; - let currentCell; - let outputHidden; - let outputCell; + let outputHidden = null; + let outputCell = null; for (let dir = 0; dir < numDirections; ++dir) { currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize]))); @@ -69,8 +67,8 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, currentPeepholeWeight.push(peepholeWeight ? (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null); - currentHidden = squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])); - currentCell = squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])); + let currentHidden = squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])); + let currentCell = squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])); for (let step = 0; step < steps; ++step) { const slice0 = dir === 1 || direction === 'backward' ? steps - step - 1 : step;