
matrix multiply fails when multiplying slice of transposed matrix, but succeeds on copy #850

@philwalk

This seems like a bug to me; hopefully it's not a misunderstanding on my part.

The expression on line 35 of my script (marked "<<<<<< this succeeds" in the code below) executes as expected, but the seemingly identical expression on line 38 (marked "<<<<<< this crashes", and the line reported in the stack trace) causes a runtime crash. AFAICT, the main difference between X.t and (X.t).copy is that the transposed matrix is a view with an offset of 1 column into a larger underlying array, due to the way it was created.
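To make the layout difference concrete, here is a simplified, self-contained model of a strided matrix view with no Breeze dependency. MatView and ViewDemo are hypothetical names; the indexing formula follows the offset / majorStride / isTranspose scheme that Breeze's DenseMatrix exposes, and the slice parameters for X (offset shifted by 1, stride kept at 3) are an assumption consistent with the description above, not something read out of the Breeze sources.

```scala
// Simplified, hypothetical model of a strided matrix view; NOT Breeze code.
// Field names mirror Breeze's DenseMatrix (offset, majorStride, isTranspose).
final class MatView(
    val rows: Int,
    val cols: Int,
    val data: Array[Double],
    val offset: Int,
    val majorStride: Int,
    val isTranspose: Boolean
) {
  // Column-major views advance majorStride per column;
  // transposed (row-major) views advance majorStride per row.
  def linearIndex(r: Int, c: Int): Int =
    if (isTranspose) offset + c + r * majorStride
    else offset + r + c * majorStride

  def apply(r: Int, c: Int): Double = data(linearIndex(r, c))

  // Transposing swaps the dimensions and flips the flag; no data is moved.
  def t: MatView = new MatView(cols, rows, data, offset, majorStride, !isTranspose)

  // Column slice [first, first + n): shares the array, only the offset changes.
  def sliceCols(first: Int, n: Int): MatView =
    if (isTranspose) new MatView(rows, n, data, offset + first, majorStride, true)
    else new MatView(rows, n, data, offset + first * majorStride, majorStride, false)

  // Copying packs the elements into a fresh, dense column-major array.
  def copy: MatView = {
    val packed = Array.tabulate(rows * cols)(i => apply(i % rows, i / rows))
    new MatView(rows, cols, packed, 0, rows, false)
  }
}

object ViewDemo {
  // Same 12 values as testdata in the report, stored row by row.
  val testdata: Array[Double] =
    Array(67, 33, 64, 62, 69, 78, 76, 58, 93, 57, 60, 60).map(_.toDouble)

  // new DenseMatrix(3, 4, data).t: a 4 x 3 transposed view of the array
  val yAndX: MatView = new MatView(3, 4, testdata, 0, 3, false).t
  // yAndX(::, 1 until 3): assumed to keep stride 3 and shift the offset by 1
  val X: MatView = yAndX.sliceCols(1, 2)
  val Xt: MatView = X.t
  val XtCopy: MatView = Xt.copy

  def main(args: Array[String]): Unit = {
    println(s"X.t      offset=${Xt.offset} majorStride=${Xt.majorStride} backing length=${Xt.data.length}")
    println(s"X.t.copy offset=${XtCopy.offset} majorStride=${XtCopy.majorStride} backing length=${XtCopy.data.length}")
  }
}
```

In this model both views yield the same 2 x 4 elements, but X.t reaches them through offset 1 and stride 3 inside the shared 12-element array, while the copy is a packed 8-element array with offset 0 and stride 2.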

The behavior is the same with both Scala 2.13.10 and Scala 3.2.1.
I have tested in the following environments:

  • Windows 10 with JDK 11, Scala 2.13.10 and Scala 3.2.1
  • WSL Ubuntu 22.04.1 with JDK 17 and Scala 3.2.2-RC1
  • Linux Mint 19.3 / Ubuntu 18.04, JDK 11, Scala 3.2.2-RC1

The code:

import breeze.linalg._

object ExpressionCrash {

  def main(args: Array[String]): Unit = {
    val yAndX = testmatrix
    printf("yAndX: %d x %d\n%s\n", yAndX.rows, yAndX.cols, yAndX)
    val y = yAndX(::,0) // column zero is y
    printf("y: %d\n%s\n", y.size, y)
    val X = yAndX(::,1 until yAndX.cols) // remainder of columns are predictors
    printf("X:  %d x %d\n%s\n", X.rows, X.cols, X)

    val N = X.cols.toDouble
    val Iᴛ: DenseMatrix[Double] = DenseMatrix.eye[Double](X.rows)
    val Iɴ: DenseMatrix[Double] = DenseMatrix.eye[Double](X.cols)

    val ιᴛ: DenseVector[Double] = DenseVector.ones[Double](X.rows)
    val ιɴ: DenseVector[Double] = DenseVector.ones[Double](X.cols)

    val T = X.rows.toDouble
    val Jᴛ: DenseMatrix[Double] = Iᴛ - 1.0/T * ιᴛ * ιᴛ.t
    val Jɴ: DenseMatrix[Double] = Iɴ - 1.0/T * ιɴ * ιɴ.t

    val ȳ = ιᴛ.t * y/T
    printf("ȳ: %s\n", ȳ)

    var Wxz: DenseVector[Double] = DenseVector.zeros(X.cols)
    val Z = DenseVector.ones[Double](X.rows)

    printf(" Jɴ dims: %s x %s\n", Jɴ.rows,  Jɴ.cols)
    printf("X.t dims: %s x %s\n", X.t.rows, X.t.cols)
    printf(" Jᴛ dims: %s x %s\n", Jᴛ.rows,  Jᴛ.cols)
    printf("  Z dims: %s x %s\n", X.rows,   X.cols)

    Wxz = Jɴ * (X.t).copy * Jᴛ * Z                                  // <<<<<< this succeeds
    printf("Wxz: %s\n", Wxz.data.take(4).toSeq)

    Wxz = Jɴ * X.t        * Jᴛ * Z                                  // <<<<<< this crashes
    printf("Wxz: %s\n", Wxz.data.take(4).toSeq)
  }

  def J(T: Double): DenseMatrix[Double] = {
    val Iᴛ: DenseMatrix[Double] = DenseMatrix.eye[Double](T.toInt)
    val ιᴛ = DenseVector.ones[Double](T.toInt)
    val Jᴛ = Iᴛ - 1.0/T * ιᴛ * ιᴛ.t
    Jᴛ
  }

  lazy val (testrows, testcols) = (4, 3)
  lazy val testmatrix = new DenseMatrix(testcols, testrows, testdata.toArray).t
  lazy val testdata = Seq(
  67,  33,  64,
  62,  69,  78,
  76,  58,  93,
  57,  60,  60,
  ).map( _.toDouble ).toArray
}
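A side note on the output of the line that succeeds: Jᴛ = Iᴛ - (1/T) ιᴛ ιᴛ' is the centering matrix, and Z is built exactly like ιᴛ (a vector of X.rows ones), so Jᴛ Z vanishes and Wxz = (0.0, 0.0) is the mathematically expected result rather than part of the problem. Using ιᴛ'ιᴛ = T:

```latex
J_T Z = \left(I_T - \tfrac{1}{T}\,\iota_T \iota_T^{\top}\right)\iota_T
      = \iota_T - \tfrac{1}{T}\,\iota_T\,\bigl(\iota_T^{\top}\iota_T\bigr)
      = \iota_T - \tfrac{1}{T}\,\iota_T\,T
      = 0,
\qquad\text{so}\qquad
W_{xz} = J_N X^{\top} J_T Z = J_N X^{\top}\,0 = 0 .
```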

Output with stack dump:

yAndX: 4 x 3
67.0  33.0  64.0
62.0  69.0  78.0
76.0  58.0  93.0
57.0  60.0  60.0
y: 4
DenseVector(67.0, 62.0, 76.0, 57.0)
X:  4 x 2
33.0  64.0
69.0  78.0
58.0  93.0
60.0  60.0
Nov 26, 2022 10:38:52 AM dev.ludovic.netlib.blas.InstanceBuilder initializeNative
WARNING: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
return java 11 instance
ȳ: 65.5
 Jɴ dims: 2 x 2
X.t dims: 2 x 4
 Jᴛ dims: 4 x 4
  Z dims: 4 x 2
Wxz: ArraySeq(0.0, 0.0)
Exception in thread "main" java.lang.IndexOutOfBoundsException: Index 12 out of bounds for length 12
        at dev.ludovic.netlib.blas.AbstractBLAS.checkIndex(AbstractBLAS.java:51)
        at dev.ludovic.netlib.blas.AbstractBLAS.dgemm(AbstractBLAS.java:295)
        at breeze.linalg.operators.DenseMatrixMultiplyOps$impl_OpMulMatrix_DMD_DMD_eq_DMD$.apply(DenseMatrixOps.expanded.scala:192)
        at breeze.linalg.operators.DenseMatrixMultiplyOps$impl_OpMulMatrix_DMD_DMD_eq_DMD$.apply(DenseMatrixOps.expanded.scala:186)
        at breeze.linalg.ImmutableNumericOps.$times(NumericOps.scala:119)
        at breeze.linalg.ImmutableNumericOps.$times$(NumericOps.scala:27)
        at breeze.linalg.DenseMatrix.$times(DenseMatrix.scala:52)
        at ExpressionCrash$.main(ExpressionCrash.scala:38)
        at ExpressionCrash.main(ExrpressionCrash.scala)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at dotty.tools.scripting.ScriptingDriver.compileAndRun(ScriptingDriver.scala:36)
        at dotty.tools.scripting.Main$.main(Main.scala:45)
        at dotty.tools.MainGenericRunner$.run$1(MainGenericRunner.scala:249)
        at dotty.tools.MainGenericRunner$.main(MainGenericRunner.scala:268)
        at dotty.tools.MainGenericRunner.main(MainGenericRunner.scala)
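The exception comes from an index check in dev.ludovic.netlib (AbstractBLAS.checkIndex, called from dgemm), not from Breeze's own code. Since * associates to the left, X.t is an operand only of the first product Jɴ * X.t, so that is presumably the dgemm call that fails. The sketch below is plain arithmetic under two unverified assumptions: that Breeze hands X.t to dgemm as a 2 x 4 column-major operand with offset 1 and leading dimension 3, and that the check reserves a full leading dimension for every column. GemmBounds and its methods are hypothetical names. Under those assumptions, the last element X.t really needs is index 11, but a full-stride bound lands on 12, which would match the reported "Index 12 out of bounds for length 12".

```scala
// Hypothetical bounds arithmetic, NOT taken from the netlib or Breeze sources.
// Models a column-major k x n operand stored at `offset` with leading dimension `ld`.
object GemmBounds {
  // Index of the last element a k x n column-major block actually reads.
  def lastAccessed(k: Int, n: Int, offset: Int, ld: Int): Int =
    offset + (n - 1) * ld + (k - 1)

  // Upper index if every one of the n columns is assumed to span a full ld elements.
  def fullStrideBound(n: Int, offset: Int, ld: Int): Int =
    offset + n * ld - 1

  def main(args: Array[String]): Unit = {
    // Assumed layout of X.t: 2 x 4, offset 1, leading dimension 3, 12-element array.
    println(s"X.t:  last read = ${lastAccessed(2, 4, 1, 3)}, full-stride bound = ${fullStrideBound(4, 1, 3)}, length = 12")
    // (X.t).copy: 2 x 4, offset 0, leading dimension 2, 8-element array.
    println(s"copy: last read = ${lastAccessed(2, 4, 0, 2)}, full-stride bound = ${fullStrideBound(4, 0, 2)}, length = 8")
  }
}
```

For the copy (offset 0, leading dimension 2, 8 elements) both numbers are 7, so either kind of check passes, which would be consistent with (X.t).copy working.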
