diff --git a/src/lstm.js b/src/lstm.js index 99474d7..a1d0fd4 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -47,12 +47,15 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, initialCellState, new Array(sizeOfShape(initialCellState)).fill(0)); } - let sequence; const currentWeight = []; const currentRecurrentWeight = []; const currentBias = []; const currentRecurrentBias = []; const currentPeepholeWeight = []; + let forwardSequence = null; + let backwardSequence = null; + let outputHidden = null; + let outputCell = null; for (let dir = 0; dir < numDirections; ++dir) { currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize]))); @@ -63,44 +66,60 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize, (squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null); currentPeepholeWeight.push(peepholeWeight ? (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null); - } - - for (let step = 0; step < steps; ++step) { - const currentHidden = []; - const currentCell = []; - let nextHidden = null; - let nextCell = null; - for (let dir = 0; dir < numDirections; ++dir) { - currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]))); - currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]))); - } + let currentHidden = squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])); + let currentCell = squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])); - for (let dir = 0; dir < numDirections; ++dir) { - const slice0 = (dir == 1 || direction == 'backward' ? steps - step - 1 : step); + for (let step = 0; step < steps; ++step) { + const slice0 = dir === 1 || direction === 'backward' ? steps - step - 1 : step; const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize])); - const results = lstmCell( + [currentHidden, currentCell] = lstmCell( currentInput, currentWeight[dir], currentRecurrentWeight[dir], - currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir], + currentHidden, currentCell, hiddenSize, {bias: currentBias[dir], recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], layout: layout, activations: activations}); - const output = reshape(results[0], [1, null, hiddenSize]); - const cell = reshape(results[1], [1, null, hiddenSize]); - - nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); - nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); + if (returnSequence) { + // Expand hidden of 2D([batchSize, hiddenSize]) to + // 4D([steps, numDirections, batchSize, hiddenSize]) + const expandedHiddenAs4D = reshape(currentHidden, [1, 1, batchSize, hiddenSize]); + if (direction === 'forward' || (dir === 0 && direction === 'both')) { + forwardSequence = forwardSequence ? + concat([forwardSequence, expandedHiddenAs4D], 0) : + expandedHiddenAs4D; + } else if (direction === 'backward' || (dir === 1 && direction === 'both')) { + backwardSequence = backwardSequence ? + concat([expandedHiddenAs4D, backwardSequence], 0) : + expandedHiddenAs4D; + } + } } - hiddenState = nextHidden; - cellState = nextCell; + // Expand hidden of 2D([batchSize, hiddenSize]) to 3D([numDirections, batchSize, hiddenSize]) + const expandHiddenAs3D = reshape(currentHidden, [1, batchSize, hiddenSize]); + // Concat along axis 0 (numDirections dimension) + outputHidden = outputHidden ? concat([outputHidden, expandHiddenAs3D], 0) : expandHiddenAs3D; - if (returnSequence) { - nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]); - sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); - } + // Expand cell of 2D([batchSize, hiddenSize]) to 3D([numDirections, batchSize, hiddenSize]) + const expandCellAs3D = reshape(currentCell, [1, batchSize, hiddenSize]); + // Concat along axis 0 (numDirections dimension) + outputCell = outputCell ? concat([outputCell, expandCellAs3D], 0) : expandCellAs3D; } - return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]); + if (returnSequence) { + // outputSequence: [steps, numDirections, batchSize, hiddenSize] + let outputSequence; + if (direction === 'forward') { + outputSequence = forwardSequence; + } else if (direction === 'backward') { + outputSequence = backwardSequence; + } else if (direction === 'both') { + // Concat along axis 1 (numDirections dimension) + outputSequence = concat([forwardSequence, backwardSequence], 1); + } + return [outputHidden, outputCell, outputSequence]; + } else { + return [outputHidden, outputCell]; + } } diff --git a/test/lstm_test.js b/test/lstm_test.js index c03d885..b887480 100644 --- a/test/lstm_test.js +++ b/test/lstm_test.js @@ -45,7 +45,6 @@ describe('test lstm', function() { input, weight, recurrentWeight, steps, hiddenSize, {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, returnSequence, activations}); - console.log('outputs: ', outputs); utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); @@ -65,7 +64,7 @@ describe('test lstm', function() { } }); - it('lstm steps=2 direction="backward" returnSequence=true' + + it('lstm steps=2 direction="backward" returnSequence=true ' + 'activations=[relu, relu, relu]', function() { const steps = 2; const numDirections = 1; @@ -106,7 +105,6 @@ describe('test lstm', function() { input, weight, recurrentWeight, steps, hiddenSize, {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, direction, returnSequence, activations}); - console.log('outputs: ', outputs); utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); @@ -114,14 +112,76 @@ describe('test lstm', function() { [10.469, 58.02899999999999, 74.529, 518.9490000000001], [5.51, 20.009999999999998, 19.11, 75.21000000000001], [ - 1, - 8, - 1, - 8, 10.469, 58.02899999999999, 74.529, 518.9490000000001, + 1, + 8, + 1, + 8, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); + + it('lstm steps=2 direction="both" returnSequence=true', function() { + const steps = 2; + const numDirections = 2; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], + new Float32Array([1, 2, 2, 1, 3, 4, 1, 2])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(2 * 4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const returnSequence = true; + const direction = 'both'; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, direction, returnSequence}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [ + 0.5764073262004139, 0.8236227651782412, + 0.6612355785279247, 0.8442635760318142, + 0.5764073262004139, 0.8236227651782412, + 0.8635294727880538, 0.9491350760903781, + ], + [ + 1.0171455721466105, 1.6205496282195793, + 1.338846378789257, 1.7642604746965693, + 1.0171455721466105, 1.6205496282195793, + 1.485626937219704, 1.8449554199024933, + ], + [ + 0.36960635293570576, 0.6082834181835157, + 0.7037753329989016, 0.7586680430344475, + 0.5764073262004139, 0.8236227651782412, + 0.8635294727880538, 0.9491350760903781, + 0.5764073262004139, 0.8236227651782412, + 0.6612355785279247, 0.8442635760318142, + 0.36960635293570576, 0.6082834181835157, + 0.36960635293570576, 0.6082834181835157, ], ]; for (let i = 0; i < expected.length; ++i) {