diff --git a/src/core/LinearLayout.test.ts b/src/core/LinearLayout.test.ts index 77a3341..b2788c0 100644 --- a/src/core/LinearLayout.test.ts +++ b/src/core/LinearLayout.test.ts @@ -249,10 +249,54 @@ describe('LinearLayout', () => { expect(() => nonSquare.invert()).toThrowError() }) - it('should reject layouts where output bits are zero but inputs are not', () => { + it('should reject rank-deficient rectangular layouts', () => { + const matrix = [ + [1, 0, 1], + [1, 0, 1], + ] + const layout = LinearLayout.fromBitMatrix( + matrix, + [{ name: 'in', size: 8 }], + [ + { name: 'out0', size: 2 }, + { name: 'out1', size: 2 }, + ] + ) + + expect(layout.isInvertible()).toBe(false) + expect(() => layout.invert()).toThrowError( + 'Layout matrix is not invertible (matrix is not full row rank)' + ) + }) + + it('should handle layouts where output bits are zero but inputs are not', () => { const broadcast = LinearLayout.zeros1D(8, 'reg', 'zero') - expect(broadcast.isInvertible()).toBe(false) - expect(() => broadcast.invert()).toThrowError(/square and surjective/i) + expect(broadcast.isInvertible()).toBe(true) + const inverted = broadcast.invert() + const recovered = inverted.apply({ zero: 0 }) + expect(recovered.reg).toBe(0) + const roundTrip = broadcast.apply(recovered) + expect(roundTrip.zero).toBe(0) + }) + + it('should compute a right inverse for rectangular surjective layouts', () => { + const layout = new LinearLayout( + [ + ['in', [[1, 0], [0, 1], [1, 1]]], + ], + ['out0', 'out1'] + ) + expect(layout.isInvertible()).toBe(true) + const inverse = layout.invert() + const out0Size = layout.getOutDimSize('out0') + const out1Size = layout.getOutDimSize('out1') + for (let out0 = 0; out0 < out0Size; out0++) { + for (let out1 = 0; out1 < out1Size; out1++) { + const recoveredInput = inverse.apply({ out0, out1 }) + const roundTrip = layout.apply(recoveredInput) + expect(roundTrip).toEqual({ out0, out1 }) + } + } }) it('should invert layouts whose input/output spaces collapse to a single value', () => { diff --git a/src/core/LinearLayout.ts b/src/core/LinearLayout.ts index 8151946..35a793f 100644 --- a/src/core/LinearLayout.ts +++ b/src/core/LinearLayout.ts @@ -272,48 +272,71 @@ function basesToMatrix( } function invertBinaryMatrix(matrix: number[][]): number[][] { - const n = matrix.length - if (n === 0) { + const rowCount = matrix.length + if (rowCount === 0) { return [] } - const width = matrix[0]?.length ?? 0 - if (width !== n) { - throw new Error('Layout matrix must be square to invert') + const colCount = matrix[0]?.length ?? 0 + if (colCount < rowCount) { + throw new Error('Layout matrix must have at least as many columns as rows to be invertible') } const augmented: number[][] = matrix.map((row, i) => { - if (row.length !== width) { + if (row.length !== colCount) { throw new Error('Layout matrix rows must have consistent width') } const left = row.map((value) => (value & 1 ? 1 : 0)) - const right = new Array(n).fill(0) + const right = new Array(rowCount).fill(0) right[i] = 1 return [...left, ...right] }) - for (let col = 0; col < n; col++) { - let pivot = col - while (pivot < n && augmented[pivot]?.[col] !== 1) { + const totalCols = colCount + rowCount + let pivotRow = 0 + const pivotColumns: number[] = [] + + for (let col = 0; col < colCount && pivotRow < rowCount; col++) { + let pivot = pivotRow + while (pivot < rowCount && augmented[pivot]?.[col] !== 1) { pivot++ } - if (pivot === n) { - throw new Error('Layout matrix is not invertible') + if (pivot === rowCount) { + continue } - if (pivot !== col) { - const temp = augmented[col]! - augmented[col] = augmented[pivot]! + if (pivot !== pivotRow) { + const temp = augmented[pivotRow]! + augmented[pivotRow] = augmented[pivot]! augmented[pivot] = temp } - for (let row = 0; row < n; row++) { - if (row !== col && augmented[row]?.[col] === 1) { - for (let k = col; k < 2 * n; k++) { - augmented[row]![k]! ^= augmented[col]![k]! + for (let row = 0; row < rowCount; row++) { + if (row !== pivotRow && augmented[row]?.[col] === 1) { + for (let k = col; k < totalCols; k++) { + augmented[row]![k]! ^= augmented[pivotRow]![k]! } } } + pivotColumns[pivotRow] = col + pivotRow++ + } + + if (pivotRow !== rowCount) { + throw new Error('Layout matrix is not invertible (matrix is not full row rank)') + } + + const rightInverse = Array.from({ length: colCount }, () => new Array(rowCount).fill(0)) + for (let row = 0; row < rowCount; row++) { + const pivotCol = pivotColumns[row] + if (pivotCol === undefined) continue + const augRow = augmented[row] + if (!augRow) continue + const targetRow = rightInverse[pivotCol] + if (!targetRow) continue + for (let j = 0; j < rowCount; j++) { + targetRow[j] = augRow[colCount + j] ?? 0 + } } - return augmented.map((row) => row.slice(n)) + return rightInverse } /** @@ -584,13 +607,13 @@ export class LinearLayout { const inputBitCount = totalBitCount(inputSpecs) const outputBitCount = totalBitCount(outputSpecs) - if (inputBitCount !== outputBitCount) { - throw new Error('Cannot invert layout: layout must be square and surjective (input/output bit counts differ)') + if (inputBitCount < outputBitCount) { + throw new Error('Cannot invert layout: layout must be surjective (input bit count is smaller than output bit count)') } let inverseMatrix: number[][] if (outputBitCount === 0) { - inverseMatrix = [] + inverseMatrix = Array.from({ length: inputBitCount }, () => []) } else { const matrix = basesToMatrix(this.bases, inputSpecs, outputSpecs) inverseMatrix = invertBinaryMatrix(matrix) @@ -610,10 +633,10 @@ export class LinearLayout { } const inputBitCount = totalBitCount(inputSpecs) const outputBitCount = totalBitCount(outputSpecs) - if (inputBitCount !== outputBitCount) { + if (inputBitCount < outputBitCount) { return false } - if (inputBitCount === 0) { + if (outputBitCount === 0) { return true } const matrix = basesToMatrix(this.bases, inputSpecs, outputSpecs)