From 9526ee00820253995936841252b7a20c7395c57c Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 07:17:47 +0100 Subject: [PATCH 01/10] Add Value Type to Tensor - Introduce factory to create tensors - introduce Value Typeclass. - Establish base pattern for defining own typeclass --- core/src/main/scala/shapeful/package.scala | 58 ++++- .../main/scala/shapeful/tensor/Tensor.scala | 206 +++++++++--------- .../main/scala/shapeful/tensor/Value.scala | 47 ++++ 3 files changed, 201 insertions(+), 110 deletions(-) create mode 100644 core/src/main/scala/shapeful/tensor/Value.scala diff --git a/core/src/main/scala/shapeful/package.scala b/core/src/main/scala/shapeful/package.scala index 653da3f..afbbfa8 100644 --- a/core/src/main/scala/shapeful/package.scala +++ b/core/src/main/scala/shapeful/package.scala @@ -17,13 +17,13 @@ package object shapeful: case Prime[l] *: tail => l *: RemovePrimes[tail] case h *: tail => h *: RemovePrimes[tail] - extension[T <: Tuple : Labels](tensor: Tensor[T]) - def dropPrimes: Tensor[RemovePrimes[T]] = + extension[T <: Tuple : Labels, V : Value](tensor: Tensor[T, V]) + def dropPrimes: Tensor[RemovePrimes[T], V] = given newLabels: Labels[RemovePrimes[T]] with val names: List[String] = val oldLabels = summon[Labels[T]] oldLabels.names.toList.map(_.replace("'", "")) - Tensor.fromPy(tensor.jaxValue) + Tensor.fromPy[RemovePrimes[T], V](tensor.jaxValue) @targetName("On") infix trait ~[A, B] @@ -42,23 +42,61 @@ package object shapeful: export shapeful.tensor.{Shape, Shape0, Shape1, Shape2, Shape3} export shapeful.tensor.{DType, Device} export shapeful.tensor.{Label, Labels, Axis, AxisIndex, AxisIndices, Dim} + export shapeful.tensor.Value + + // Opaque types for DTypes - clean imports and display without .type suffix + opaque type Float32 = DType.Float32.type + object Float32: + given Value[Float32] = summon[Value[DType.Float32.type]] + + opaque type Float64 = DType.Float64.type + object Float64: + given Value[Float64] = summon[Value[DType.Float64.type]] + + opaque type Int32 = DType.Int32.type + object Int32: + given Value[Int32] = summon[Value[DType.Int32.type]] + + opaque type Int64 = DType.Int64.type + object Int64: + given Value[Int64] = summon[Value[DType.Int64.type]] + + opaque type Int16 = DType.Int16.type + object Int16: + given Value[Int16] = summon[Value[DType.Int16.type]] + + opaque type Int8 = DType.Int8.type + object Int8: + given Value[Int8] = summon[Value[DType.Int8.type]] + + opaque type UInt32 = DType.UInt32.type + object UInt32: + given Value[UInt32] = summon[Value[DType.UInt32.type]] + + opaque type UInt16 = DType.UInt16.type + object UInt16: + given Value[UInt16] = summon[Value[DType.UInt16.type]] + + opaque type UInt8 = DType.UInt8.type + object UInt8: + given Value[UInt8] = summon[Value[DType.UInt8.type]] + + opaque type Bool = DType.Bool.type + object Bool: + given Value[Bool] = summon[Value[DType.Bool.type]] // Export type helpers export shapeful.tensor.Axis.UnwrapAxes export shapeful.tensor.TupleHelpers.* - export shapeful.tensor.Broadcast + //export shapeful.tensor.Broadcast export Prime.* // Export operations - export shapeful.tensor.TensorOps.* + //export shapeful.tensor.TensorOps.* // Export automatic differentiation export shapeful.autodiff.{Autodiff, TensorTree, ToPyTree} // Export Just-in-Time compilation - export shapeful.jax.Jit.{jit, jit2} - - // Export implicit conversions - object Conversions: - export shapeful.tensor.Tensor0.{given_Conversion_Int_Tensor0, given_Conversion_Float_Tensor0} +// export shapeful.jax.Jit.{jit, 
jit2} diff --git a/core/src/main/scala/shapeful/tensor/Tensor.scala b/core/src/main/scala/shapeful/tensor/Tensor.scala index 52582c2..aacd706 100644 --- a/core/src/main/scala/shapeful/tensor/Tensor.scala +++ b/core/src/main/scala/shapeful/tensor/Tensor.scala @@ -6,7 +6,7 @@ import shapeful.jax.Jax import shapeful.jax.JaxDType import shapeful.jax.Jax.PyDynamic import shapeful.tensor.{Label, Labels} -import shapeful.random.Random +//import shapeful.random.Random import me.shadaj.scalapy.py.SeqConverters enum Device(val jaxDevice: PyDynamic): @@ -20,7 +20,7 @@ object Device: Device.CPU ) -class Tensor[T <: Tuple : Labels] private[tensor]( +class Tensor[T <: Tuple : Labels, V : Value] private[tensor]( val jaxValue: Jax.PyDynamic, ): @@ -32,13 +32,13 @@ class Tensor[T <: Tuple : Labels] private[tensor]( d => Jax.device_get(jaxValue).equals(d.jaxDevice) ).getOrElse(Device.Other(Jax.device_get(jaxValue).name.as[String])) - def asType(newDType: DType): Tensor[T] = - Tensor(jaxValue = Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(newDType))) + def asType[V2 : Value](newDType: DType): Tensor[T, V2] = + Tensor[T, V2](jaxValue = Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(newDType))) - def toDevice(newDevice: Device): Tensor[T] = - Tensor(jaxValue = Jax.device_put(jaxValue, newDevice.jaxDevice)) + def toDevice(newDevice: Device): Tensor[T, V] = + Tensor[T, V](jaxValue = Jax.device_put(jaxValue, newDevice.jaxDevice)) - def equals(other: Tensor[T]): Boolean = + def equals(other: Tensor[T, V]): Boolean = Jax.jnp.array_equal(this.jaxValue, other.jaxValue).item().as[Boolean] override def hashCode(): Int = jaxArray.tobytes().hashCode() @@ -54,118 +54,124 @@ object Tensor: type IndicesOf[T <: Tuple] = Tuple.Map[T, [ _ ] =>> Int] - private[tensor] def apply[T <: Tuple : Labels](jaxValue: Jax.PyDynamic): Tensor[T] = new Tensor[T](jaxValue) + private[tensor] def apply[T <: Tuple : Labels, V : Value](jaxValue: Jax.PyDynamic): Tensor[T, V] = new Tensor[T, V](jaxValue) - def fromPy[T <: Tuple : Labels](jaxValue: Jax.PyDynamic): Tensor[T] = Tensor(jaxValue) + def fromPy[T <: Tuple : Labels, V : Value](jaxValue: Jax.PyDynamic): Tensor[T, V] = Tensor[T, V](jaxValue) - def apply[T <: Tuple : Labels](shape: Shape[T])(using initTensor: Shape[T] => Tensor[T]): Tensor[T] = - initTensor(shape) + // Builder pattern for type-safe tensor creation + class TensorBuilder[V : Value]: + private val dtype = summon[Value[V]].dtype - def apply[T <: Tuple : Labels](shape: Shape[T], values: Array[Float], dtype: DType = DType.Float32, device: Device = Device.default): Tensor[T] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val jaxValues = Jax.jnp - .array( - values.toPythonProxy, - dtype = dtype.jaxType, - device = device.jaxDevice, - ) - .reshape(shape.dimensions.toPythonProxy) - Tensor(jaxValues) + def apply[T <: Tuple : Labels](shape: Shape[T], values: Array[Float], device: Device = Device.default): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val jaxValues = Jax.jnp + .array( + values.toPythonProxy, + dtype = dtype.jaxType, + device = device.jaxDevice, + ) + .reshape(shape.dimensions.toPythonProxy) + Tensor[T, V](jaxValues) - def zeros[T <: Tuple : Labels](shape: Shape[T], dtype: DType = DType.Float32): Tensor[T] = - Tensor(Jax.jnp.zeros(shape.dimensions.toPythonProxy, dtype = dtype.jaxType)) + def zeros[T <: Tuple : Labels](shape: Shape[T]): Tensor[T, V] = + Tensor[T, 
V](Jax.jnp.zeros(shape.dimensions.toPythonProxy, dtype = dtype.jaxType)) - def ones[T <: Tuple : Labels](shape: Shape[T], dtype: DType = DType.Float32): Tensor[T] = - Tensor(Jax.jnp.ones(shape.dimensions.toPythonProxy, dtype = dtype.jaxType)) + def ones[T <: Tuple : Labels](shape: Shape[T]): Tensor[T, V] = + Tensor[T, V](Jax.jnp.ones(shape.dimensions.toPythonProxy, dtype = dtype.jaxType)) - def randn[T <: Tuple : Labels](shape: Shape[T], key: Random.Key, dtype: DType = DType.Float32): Tensor[T] = - Random.normal(key, shape, dtype = dtype) + def of[V : Value]: TensorBuilder[V] = new TensorBuilder[V] -type Tensor0 = Tensor[EmptyTuple] -type Tensor1[L] = Tensor[Tuple1[L]] -type Tensor2[L1, L2] = Tensor[(L1, L2)] -type Tensor3[L1, L2, L3] = Tensor[(L1, L2, L3)] -type Tensor4[L1, L2, L3, L4] = Tensor[(L1, L2, L3, L4)] +type Tensor0[V] = Tensor[EmptyTuple, V] +type Tensor1[L, V] = Tensor[Tuple1[L], V] +type Tensor2[L1, L2, V] = Tensor[(L1, L2), V] +type Tensor3[L1, L2, L3, V] = Tensor[(L1, L2, L3), V] +type Tensor4[L1, L2, L3, L4, V] = Tensor[(L1, L2, L3, L4), V] object Tensor0: - // implicit conversions for easy creation - given Conversion[Float, Tensor0] = (x: Float) => Tensor0(x) - given Conversion[Int, Tensor0] = (x: Int) => Tensor0(x) + class Tensor0Builder[V : Value]: + private val dtype = summon[Value[V]].dtype + + def apply(value: Float): Tensor0[V] = + Tensor0[V](Jax.jnp.array(value, dtype=dtype.jaxType)) + + def apply(value: Int): Tensor0[V] = + Tensor0[V](Jax.jnp.array(value, dtype=dtype.jaxType)) + + def apply(value: Boolean): Tensor0[V] = + Tensor0[V](Jax.jnp.array(value, dtype=dtype.jaxType)) - def apply(jaxValue: Jax.PyDynamic): Tensor0 = Tensor(jaxValue) + def of[V : Value]: Tensor0Builder[V] = new Tensor0Builder[V] - def apply(value: Float | Int | Boolean): Tensor0 = - value match - case v: Float => Tensor0(Jax.jnp.array(v, dtype=DType.Float32.jaxType)) - case v: Int => Tensor0(Jax.jnp.array(v, dtype=DType.Int32.jaxType)) - case v: Boolean => Tensor0(Jax.jnp.array(v, dtype=DType.Bool.jaxType)) + private[tensor] def apply[V : Value](jaxValue: Jax.PyDynamic): Tensor0[V] = Tensor[EmptyTuple, V](jaxValue) object Tensor1: - def apply[L : Label](axis: Axis[L], values: Array[Float], dtype: DType = DType.Float32): Tensor1[L] = - Tensor(Jax.jnp.array(values.toPythonProxy, dtype = dtype.jaxType)) + class Tensor1Builder[V : Value]: + private val dtype = summon[Value[V]].dtype - def fromInts[L : Label](axis: Axis[L], values: Array[Int], dtype: DType = DType.Int32): Tensor1[L] = - Tensor(Jax.jnp.array(values.toPythonProxy, dtype = dtype.jaxType)) + def apply[L : Label](axis: Axis[L], values: Array[Float]): Tensor1[L, V] = + Tensor[Tuple1[L], V](Jax.jnp.array(values.toPythonProxy, dtype = dtype.jaxType)) + + def fromInts[L : Label](axis: Axis[L], values: Array[Int]): Tensor1[L, V] = + Tensor[Tuple1[L], V](Jax.jnp.array(values.toPythonProxy, dtype = dtype.jaxType)) + + def of[V : Value]: Tensor1Builder[V] = new Tensor1Builder[V] object Tensor2: - def apply[L1 : Label, L2 : Label]( - shape: Shape2[L1, L2], - values: Array[Float], - dtype: DType, - ): Tensor2[L1, L2] = Tensor(shape, values, dtype) - - def apply[L1 : Label, L2 : Label]( - shape: Shape2[L1, L2], - values: Array[Float], - ): Tensor[(L1, L2)] = Tensor2(shape, values, DType.Float32) - - def apply[L1 : Label, L2 : Label]( - axis1: Axis[L1], - axis2: Axis[L2], - values: Array[Array[Float]], - dtype: DType = DType.Float32, - ): Tensor[(L1, L2)] = - val rows = values.length - val cols = values.headOption.map(_.length).getOrElse(0) - 
require(values.forall(_.length == cols), "All rows must have the same length") - Tensor2(Shape(axis1 -> rows, axis2 -> cols), values.flatten, dtype) - - def eye[L : Label](axis: Axis[L])(dim: Int, dtype: DType = DType.Float32): Tensor2[L, L] = - Tensor(Jax.jnp.eye(dim, dtype = dtype.jaxType)) - - def diag[L : Label](diag: Tensor1[L]): Tensor2[L, L] = - Tensor(Jax.jnp.diag(diag.jaxValue)) + class Tensor2Builder[V : Value]: + private val dtype = summon[Value[V]].dtype + + def apply[L1 : Label, L2 : Label]( + shape: Shape2[L1, L2], + values: Array[Float], + ): Tensor2[L1, L2, V] = + Tensor.of[V].apply(shape, values) + + def apply[L1 : Label, L2 : Label]( + axis1: Axis[L1], + axis2: Axis[L2], + values: Array[Array[Float]], + ): Tensor2[L1, L2, V] = + val rows = values.length + val cols = values.headOption.map(_.length).getOrElse(0) + require(values.forall(_.length == cols), "All rows must have the same length") + Tensor2.of[V].apply(Shape(axis1 -> rows, axis2 -> cols), values.flatten) + + def eye[L : Label](axis: Axis[L])(dim: Int): Tensor2[L, L, V] = + Tensor[Tuple2[L, L], V](Jax.jnp.eye(dim, dtype = dtype.jaxType)) + + def diag[L : Label](diag: Tensor1[L, V]): Tensor2[L, L, V] = + Tensor[Tuple2[L, L], V](Jax.jnp.diag(diag.jaxValue)) + + def of[V : Value]: Tensor2Builder[V] = new Tensor2Builder[V] object Tensor3: - def apply[L1 : Label, L2 : Label, L3 : Label]( - shape: Shape3[L1, L2, L3], - values: Array[Float], - dtype: DType, - ): Tensor3[L1, L2, L3] = Tensor(shape, values, dtype) - - def apply[L1 : Label, L2 : Label, L3 : Label]( - shape: Shape3[L1, L2, L3], - values: Array[Float], - ): Tensor3[L1, L2, L3] = Tensor3(shape, values, DType.Float32) - - def apply[L1 : Label, L2 : Label, L3 : Label]( - axis1: Axis[L1], - axis2: Axis[L2], - axis3: Axis[L3], - values: Array[Array[Array[Float]]], - dtype: DType = DType.Float32, - ): Tensor3[L1, L2, L3] = - val dim1 = values.length - val dim2 = values.headOption.map(_.length).getOrElse(0) - val dim3 = values.headOption.flatMap(_.headOption).map(_.length).getOrElse(0) - require(values.forall(_.length == dim2), "All second dimensions must match") - require(values.forall(_.forall(_.length == dim3)), "All third dimensions must match") - Tensor3( - Shape(axis1 -> dim1, axis2 -> dim2, axis3 -> dim3), - values.flatten.flatten, - dtype, - ) + class Tensor3Builder[V : Value]: + private val dtype = summon[Value[V]].dtype + + def apply[L1 : Label, L2 : Label, L3 : Label]( + shape: Shape3[L1, L2, L3], + values: Array[Float], + ): Tensor3[L1, L2, L3, V] = + Tensor.of[V].apply(shape, values) + + def apply[L1 : Label, L2 : Label, L3 : Label]( + axis1: Axis[L1], + axis2: Axis[L2], + axis3: Axis[L3], + values: Array[Array[Array[Float]]], + ): Tensor3[L1, L2, L3, V] = + val dim1 = values.length + val dim2 = values.headOption.map(_.length).getOrElse(0) + val dim3 = values.headOption.flatMap(_.headOption).map(_.length).getOrElse(0) + require(values.forall(_.length == dim2), "All second dimensions must match") + require(values.forall(_.forall(_.length == dim3)), "All third dimensions must match") + Tensor3.of[V].apply( + Shape(axis1 -> dim1, axis2 -> dim2, axis3 -> dim3), + values.flatten.flatten, + ) + + def of[V : Value]: Tensor3Builder[V] = new Tensor3Builder[V] diff --git a/core/src/main/scala/shapeful/tensor/Value.scala b/core/src/main/scala/shapeful/tensor/Value.scala new file mode 100644 index 0000000..00c676e --- /dev/null +++ b/core/src/main/scala/shapeful/tensor/Value.scala @@ -0,0 +1,47 @@ +package shapeful.tensor + +// Typeclass linking DType singleton 
types to runtime DType values +trait Value[V]: + def dtype: DType + +object Value: + // Helper to summon instances + inline def summon[V](using v: Value[V]): Value[V] = v + + // Simple two-line pattern for semantic types: + // trait Y + // object Y extends Value.As[Y, Float32] + abstract class As[V, BaseType](using base: Value[BaseType]) extends Value[V]: + def dtype: DType = base.dtype + given Value[V] = this + + // Given instances for all supported DType enum cases + given Value[DType.Float32.type] with + def dtype: DType = DType.Float32 + + given Value[DType.Float64.type] with + def dtype: DType = DType.Float64 + + given Value[DType.Int32.type] with + def dtype: DType = DType.Int32 + + given Value[DType.Int64.type] with + def dtype: DType = DType.Int64 + + given Value[DType.Int16.type] with + def dtype: DType = DType.Int16 + + given Value[DType.Int8.type] with + def dtype: DType = DType.Int8 + + given Value[DType.UInt32.type] with + def dtype: DType = DType.UInt32 + + given Value[DType.UInt16.type] with + def dtype: DType = DType.UInt16 + + given Value[DType.UInt8.type] with + def dtype: DType = DType.UInt8 + + given Value[DType.Bool.type] with + def dtype: DType = DType.Bool From 13bf82465222587bccddf688d5ac496ab09bfd95 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 09:12:31 +0100 Subject: [PATCH 02/10] add tensorops back with new Value type argument --- core/src/main/scala/shapeful/package.scala | 4 +- .../main/scala/shapeful/random/Random.scala | 32 +- .../scala/shapeful/tensor/TensorOps.scala | 356 +++++++++--------- .../src/main/scala/basic/Playground.scala | 156 ++++---- 4 files changed, 285 insertions(+), 263 deletions(-) diff --git a/core/src/main/scala/shapeful/package.scala b/core/src/main/scala/shapeful/package.scala index afbbfa8..7d3ca3e 100644 --- a/core/src/main/scala/shapeful/package.scala +++ b/core/src/main/scala/shapeful/package.scala @@ -88,11 +88,11 @@ package object shapeful: // Export type helpers export shapeful.tensor.Axis.UnwrapAxes export shapeful.tensor.TupleHelpers.* - //export shapeful.tensor.Broadcast + export shapeful.tensor.Broadcast export Prime.* // Export operations - //export shapeful.tensor.TensorOps.* + export shapeful.tensor.TensorOps.* // Export automatic differentiation export shapeful.autodiff.{Autodiff, TensorTree, ToPyTree} diff --git a/core/src/main/scala/shapeful/random/Random.scala b/core/src/main/scala/shapeful/random/Random.scala index 8231cc2..1311312 100644 --- a/core/src/main/scala/shapeful/random/Random.scala +++ b/core/src/main/scala/shapeful/random/Random.scala @@ -49,36 +49,36 @@ object Random: */ /** Normal distribution with specified mean and standard deviation */ - def normal[T <: Tuple : Labels]( + def normal[T <: Tuple : Labels, V : Value]( key: Key, shape: Shape[T], - mean: Tensor0 = Tensor0(0f), - std: Tensor0 = Tensor0(1f), - dtype: DType = DType.Float32 - ): Tensor[T] = + mean: Tensor0[V] = Tensor0.of[DType.Float32.type].apply(0f), + std: Tensor0[V] = Tensor0.of[DType.Float32.type].apply(1f) + )(using ev: V =:= DType.Float32.type): Tensor[T, V] = + val dtype = summon[Value[V]].dtype val jaxValues = Jax.jrandom.normal( key.jaxKey, shape.dimensions.toPythonProxy, dtype = JaxDType.jaxDtype(dtype) ) - val standardNormal = Tensor.fromPy[T](jaxValues) + val standardNormal = Tensor.fromPy[T, V](jaxValues) standardNormal :* std :+ mean /** Uniform distribution in [0, 1) */ - def uniform[T <: Tuple : Labels]( + def uniform[T <: Tuple : Labels, V : Value]( key: Key, - shape: Shape[T], - dtype: DType = 
DType.Float32 - ): Tensor[T] = uniform(key, shape, Tensor0(0f), Tensor0(1f), dtype) + shape: Shape[T] + )(using ev: V =:= DType.Float32.type): Tensor[T, V] = + uniform(key, shape, Tensor0.of[V].apply(0f), Tensor0.of[V].apply(1f)) /** Uniform distribution in [minval, maxval) */ - def uniform[T <: Tuple : Labels]( + def uniform[T <: Tuple : Labels, V : Value]( key: Key, shape: Shape[T], - minval: Tensor0, - maxval: Tensor0, - dtype: DType - ): Tensor[T] = + minval: Tensor0[V], + maxval: Tensor0[V] + ): Tensor[T, V] = + val dtype = summon[Value[V]].dtype val jaxValues = Jax.jrandom.uniform( key.jaxKey, shape.dimensions.toPythonProxy, @@ -86,5 +86,5 @@ object Random: maxval = maxval.jaxValue, dtype = JaxDType.jaxDtype(dtype) ) - Tensor.fromPy[T](jaxValues) + Tensor.fromPy[T, V](jaxValues) diff --git a/core/src/main/scala/shapeful/tensor/TensorOps.scala b/core/src/main/scala/shapeful/tensor/TensorOps.scala index cc9cbb1..31a09ef 100644 --- a/core/src/main/scala/shapeful/tensor/TensorOps.scala +++ b/core/src/main/scala/shapeful/tensor/TensorOps.scala @@ -8,33 +8,34 @@ import shapeful.tensor.{Label, Labels} import shapeful.tensor.Axis.UnwrapAxes import scala.util.NotGiven import scala.collection.View.Empty +import shapeful.Bool import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import shapeful.tensor.TupleHelpers.PrimeConcat import shapeful.{~, `|*|`} -trait Broadcast[T1 <: Tuple, T2 <: Tuple, Out <: Tuple]: - def broadcast(t1: Tensor[T1], t2: Tensor[T2]): (Tensor[Out], Tensor[Out]) +trait Broadcast[T1 <: Tuple, T2 <: Tuple, Out <: Tuple, V : Value]: + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]): (Tensor[Out, V], Tensor[Out, V]) object Broadcast: import shapeful.tensor.TensorOps.Structural.lift - given identity[T <: Tuple]: Broadcast[T, T, T] with - def broadcast(t1: Tensor[T], t2: Tensor[T]): (Tensor[T], Tensor[T]) = (t1, t2) + given identity[T <: Tuple, V : Value]: Broadcast[T, T, T, V] with + def broadcast(t1: Tensor[T, V], t2: Tensor[T, V]): (Tensor[T, V], Tensor[T, V]) = (t1, t2) - given broadcastToLeft[T1 <: Tuple : Labels, T2 <: Tuple : Labels](using + given broadcastToLeft[T1 <: Tuple : Labels, T2 <: Tuple : Labels, V : Value](using ev: Subset[T1, T2] - ): Broadcast[T1, T2, T1] with - def broadcast(t1: Tensor[T1], t2: Tensor[T2]): (Tensor[T1], Tensor[T1]) = + ): Broadcast[T1, T2, T1, V] with + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]): (Tensor[T1, V], Tensor[T1, V]) = val liftedT2 = t2.lift[T1](t1.shape) (t1, liftedT2) - given broadcastToRight[T1 <: Tuple : Labels, T2 <: Tuple : Labels](using + given broadcastToRight[T1 <: Tuple : Labels, T2 <: Tuple : Labels, V : Value](using ev: Subset[T2, T1], - ): Broadcast[T1, T2, T2] with - def broadcast(t1: Tensor[T1], t2: Tensor[T2]): (Tensor[T2], Tensor[T2]) = + ): Broadcast[T1, T2, T2, V] with + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]): (Tensor[T2, V], Tensor[T2, V]) = val liftedT1 = t1.lift[T2](t2.shape) (liftedT1, t2) @@ -46,63 +47,63 @@ object TensorOps: // ----------------------------------------------------------- object Elementwise: - def maximum[T <: Tuple : Labels](t1: Tensor[T], t2: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.maximum(t1.jaxValue, t2.jaxValue)) - def minimum[T <: Tuple : Labels](t1: Tensor[T], t2: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) + def maximum[T <: Tuple : Labels, V : Value](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.maximum(t1.jaxValue, t2.jaxValue)) + def minimum[T <: Tuple : Labels, V : Value](t1: Tensor[T, V], 
t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) - extension [T <: Tuple : Labels](t: Tensor[T]) + extension [T <: Tuple : Labels, V : Value](t: Tensor[T, V]) - def +(other: Tensor[T]): Tensor[T] = t.add(other) - def :+[O <: Tuple](other: Tensor[O])(using broadcaster: Broadcast[T, O, T]): Tensor[T] = broadcaster.broadcast(t, other) match { case (l, r) => l.add(r) } - def +:[O <: Tuple : Labels](other: Tensor[O])(using broadcaster: Broadcast[O, T, O]): Tensor[O] = broadcaster.broadcast(other, t) match { case (r, l) => l.add(r) } - private def add(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.add(t.jaxValue, other.jaxValue)) + def +(other: Tensor[T, V]): Tensor[T, V] = t.add(other) + def :+[O <: Tuple](other: Tensor[O, V])(using broadcaster: Broadcast[T, O, T, V]): Tensor[T, V] = broadcaster.broadcast(t, other) match { case (l, r) => l.add(r) } + def +:[O <: Tuple : Labels](other: Tensor[O, V])(using broadcaster: Broadcast[O, T, O, V]): Tensor[O, V] = broadcaster.broadcast(other, t) match { case (r, l) => l.add(r) } + private def add(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t.jaxValue, other.jaxValue)) - def unary_- : Tensor[T] = Tensor(Jax.jnp.negative(t.jaxValue)) - def -(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.subtract(t.jaxValue, other.jaxValue)) - def :-[O <: Tuple](other: Tensor[O])(using broadcaster: Broadcast[T, O, T]): Tensor[T] = broadcaster.broadcast(t, other) match { case (l, r) => l.subtract(r) } - def -:[O <: Tuple : Labels](other: Tensor[O])(using broadcaster: Broadcast[O, T, O]): Tensor[O] = broadcaster.broadcast(other, t) match { case (r, l) => l.subtract(r) } - private def subtract(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.subtract(t.jaxValue, other.jaxValue)) - - def *(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) - def scale(other: Tensor0) = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) - def :*[O <: Tuple](other: Tensor[O])(using broadcaster: Broadcast[T, O, T]): Tensor[T] = broadcaster.broadcast(t, other) match { case (l, r) => l.multiply(r) } - def *:[O <: Tuple : Labels](other: Tensor[O])(using broadcaster: Broadcast[O, T, O]): Tensor[O] = broadcaster.broadcast(other, t) match { case (r, l) => l.multiply(r) } - private def multiply(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) + def unary_- : Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue)) + def -(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, other.jaxValue)) + def :-[O <: Tuple](other: Tensor[O, V])(using broadcaster: Broadcast[T, O, T, V]): Tensor[T, V] = broadcaster.broadcast(t, other) match { case (l, r) => l.subtract(r) } + def -:[O <: Tuple : Labels](other: Tensor[O, V])(using broadcaster: Broadcast[O, T, O, V]): Tensor[O, V] = broadcaster.broadcast(other, t) match { case (r, l) => l.subtract(r) } + private def subtract(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, other.jaxValue)) + + def *(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) + def scale(other: Tensor0[V]) = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) + def :*[O <: Tuple](other: Tensor[O, V])(using broadcaster: Broadcast[T, O, T, V]): Tensor[T, V] = broadcaster.broadcast(t, other) match { case (l, r) => l.multiply(r) } + def *:[O <: Tuple : Labels](other: Tensor[O, V])(using broadcaster: Broadcast[O, T, O, V]): Tensor[O, V] = broadcaster.broadcast(other, t) match { case (r, l) => l.multiply(r) } + 
private def multiply(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t.jaxValue, other.jaxValue)) - def /(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.divide(t.jaxValue, other.jaxValue)) - def :/[O <: Tuple](other: Tensor[O])(using broadcaster: Broadcast[T, O, T]): Tensor[T] = broadcaster.broadcast(t, other) match { case (l, r) => l.divide(r) } - def /:[O <: Tuple : Labels](other: Tensor[O])(using broadcaster: Broadcast[O, T, O]): Tensor[O] = broadcaster.broadcast(other, t) match { case (r, l) => l.divide(r) } - private def divide(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.divide(t.jaxValue, other.jaxValue)) + def /(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t.jaxValue, other.jaxValue)) + def :/[O <: Tuple](other: Tensor[O, V])(using broadcaster: Broadcast[T, O, T, V]): Tensor[T, V] = broadcaster.broadcast(t, other) match { case (l, r) => l.divide(r) } + def /:[O <: Tuple : Labels](other: Tensor[O, V])(using broadcaster: Broadcast[O, T, O, V]): Tensor[O, V] = broadcaster.broadcast(other, t) match { case (r, l) => l.divide(r) } + private def divide(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t.jaxValue, other.jaxValue)) // --- Unary Math --- - def abs: Tensor[T] = Tensor(Jax.jnp.abs(t.jaxValue)) - def sign: Tensor[T] = Tensor(Jax.jnp.sign(t.jaxValue)) - def pow(n: Tensor0): Tensor[T] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue)) - def sqrt: Tensor[T] = Tensor(Jax.jnp.sqrt(t.jaxValue)) - def exp: Tensor[T] = Tensor(Jax.jnp.exp(t.jaxValue)) - def log: Tensor[T] = Tensor(Jax.jnp.log(t.jaxValue)) - def sin: Tensor[T] = Tensor(Jax.jnp.sin(t.jaxValue)) - def cos: Tensor[T] = Tensor(Jax.jnp.cos(t.jaxValue)) - def tanh: Tensor[T] = Tensor(Jax.jnp.tanh(t.jaxValue)) + def abs: Tensor[T, V] = Tensor(Jax.jnp.abs(t.jaxValue)) + def sign: Tensor[T, V] = Tensor(Jax.jnp.sign(t.jaxValue)) + def pow(n: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue)) + def sqrt: Tensor[T, V] = Tensor(Jax.jnp.sqrt(t.jaxValue)) + def exp: Tensor[T, V] = Tensor(Jax.jnp.exp(t.jaxValue)) + def log: Tensor[T, V] = Tensor(Jax.jnp.log(t.jaxValue)) + def sin: Tensor[T, V] = Tensor(Jax.jnp.sin(t.jaxValue)) + def cos: Tensor[T, V] = Tensor(Jax.jnp.cos(t.jaxValue)) + def tanh: Tensor[T, V] = Tensor(Jax.jnp.tanh(t.jaxValue)) // --- Clipping --- - def clip(min: Float, max: Float): Tensor[T] = Tensor(Jax.jnp.clip(t.jaxValue, min, max)) - def clip(min: Tensor0, max: Tensor0): Tensor[T] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue)) + def clip(min: Float, max: Float): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min, max)) + def clip(min: Tensor0[V], max: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue)) // --- Comparison --- - def <(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) - def <=(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) - def >(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) - def >=(other: Tensor[T]): Tensor[T] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) + def <(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) + def <=(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) + def >(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) + def >=(other: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) - def elementEquals(other: 
Tensor[T]): Tensor[T] = + def elementEquals(other: Tensor[T, V]): Tensor[T, V] = require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) def all: Boolean = Tensor0(Jax.jnp.all(t.jaxValue)).toBool def any: Boolean = Tensor0(Jax.jnp.any(t.jaxValue)).toBool - def approxEquals(other: Tensor[T], tolerance: Float = 1e-6f): Boolean = approxElementEquals(other, tolerance).all - def approxElementEquals(other: Tensor[T], tolerance: Float = 1e-6f): Tensor[T] = + def approxEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Boolean = approxElementEquals(other, tolerance).all + def approxElementEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor[T, V] = Tensor(Jax.jnp.allclose( t.jaxValue, other.jaxValue, @@ -118,48 +119,48 @@ object TensorOps: // ----------------------------------------------------------- object Reduction: - extension [T <: Tuple : Labels](t: Tensor[T]) + extension [T <: Tuple : Labels, V : Value](t: Tensor[T, V]) // --- Sum --- - def sum: Tensor0 = Tensor0(Jax.jnp.sum(t.jaxValue)) - def sum[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.sum(t.jaxValue, axis = axisIndex.value)) - def sum[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.sum(t.jaxValue, axis = axesIndices.values.toPythonProxy)) + def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue)) + def sum[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = axisIndex.value)) + def sum[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = axesIndices.values.toPythonProxy)) // --- Mean --- - def mean: Tensor0 = Tensor0(Jax.jnp.mean(t.jaxValue)) - def mean[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.mean(t.jaxValue, axis = axisIndex.value)) - def mean[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.mean(t.jaxValue, axis = axesIndices.values.toPythonProxy)) + def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue)) + def mean[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = axisIndex.value)) + def mean[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = axesIndices.values.toPythonProxy)) // --- Std --- - def std: Tensor0 = Tensor0(Jax.jnp.std(t.jaxValue)) - def std[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.std(t.jaxValue, axis = axisIndex.value)) - def std[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: 
RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.std(t.jaxValue, axis = axesIndices.values.toPythonProxy)) + def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue)) + def std[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = axisIndex.value)) + def std[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = axesIndices.values.toPythonProxy)) // --- Max --- - def max: Tensor0 = Tensor0(Jax.jnp.max(t.jaxValue)) - def max[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.max(t.jaxValue, axis = axisIndex.value)) - def max[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.max(t.jaxValue, axis = axesIndices.values.toPythonProxy)) + def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue)) + def max[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = axisIndex.value)) + def max[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = axesIndices.values.toPythonProxy)) // --- Min --- - def min: Tensor0 = Tensor0(Jax.jnp.min(t.jaxValue)) - def min[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.min(t.jaxValue, axis = axisIndex.value)) - def min[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.min(t.jaxValue, axis = axesIndices.values.toPythonProxy)) + def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue)) + def min[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = axisIndex.value)) + def min[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using remover: RemoverAll.Aux[T, UnwrapAxes[Inputs], R], axesIndices: AxisIndices[T, UnwrapAxes[Inputs]], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = axesIndices.values.toPythonProxy)) // --- Argmax --- - def argmax: Tensor0 = Tensor0(Jax.jnp.argmax(t.jaxValue)) - def argmax[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = axisIndex.value)) + def argmax: Tensor0[V] = Tensor0(Jax.jnp.argmax(t.jaxValue)) + def argmax[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = axisIndex.value)) // --- Argmin --- - def argmin: Tensor0 = Tensor0(Jax.jnp.argmin(t.jaxValue)) - def argmin[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: 
AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = axisIndex.value)) + def argmin: Tensor0[V] = Tensor0(Jax.jnp.argmin(t.jaxValue)) + def argmin[L : Label, R <: Tuple](axis: Axis[L])(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = axisIndex.value)) end Reduction object Contraction: - extension [T <: Tuple : Labels](tensor: Tensor[T]) + extension [T <: Tuple : Labels, V : Value](tensor: Tensor[T, V]) - def outerProduct[OtherShape <: Tuple : Labels, Out <: Tuple](other: Tensor[OtherShape])( + def outerProduct[OtherShape <: Tuple : Labels, Out <: Tuple](other: Tensor[OtherShape, V])( using primeConcat: PrimeConcat.Aux[T, OtherShape, Out], - ): Tensor[Out] = + ): Tensor[Out, V] = import Labels.ForPrimeConcat.given Tensor( // Jax outer product flattens, reshape required @@ -177,7 +178,7 @@ object TensorOps: Out <: Tuple ] (axis: Axis[ContractAxis]) - (other: Tensor[OtherShape])(using + (other: Tensor[OtherShape, V])(using remover: RemoverAll.Aux[T, ContractAxis *: EmptyTuple, R1], otherRemover: RemoverAll.Aux[OtherShape, ContractAxis *: EmptyTuple, R2], axisIndex: AxisIndex[T, ContractAxis], @@ -185,7 +186,7 @@ object TensorOps: primeConcat: PrimeConcat.Aux[R1, R2, Out], r1Labels: Labels[R1], r2Labels: Labels[R2], - ): Tensor[Out] = + ): Tensor[Out, V] = import Labels.ForPrimeConcat.given val axesTuple1 = Jax.Dynamic.global.tuple(Seq(axisIndex.value).toPythonProxy) val axesTuple2 = Jax.Dynamic.global.tuple(Seq(otherAxisIndex.value).toPythonProxy) @@ -203,14 +204,14 @@ object TensorOps: Out <: Tuple ] (axis: Axis[ContractAxisA ~ ContractAxisB]) - (other: Tensor[OtherShape])(using + (other: Tensor[OtherShape, V])(using remover: RemoverAll.Aux[T, ContractAxisA *: EmptyTuple, R1], otherRemover: RemoverAll.Aux[OtherShape, ContractAxisB *: EmptyTuple, R2], axisIndex: AxisIndex[T, ContractAxisA], otherAxisIndex: AxisIndex[OtherShape, ContractAxisB], primeConcat: PrimeConcat.Aux[R1, R2, Out], outLabels: Labels[Out], - ): Tensor[Out] = + ): Tensor[Out, V] = import Labels.ForPrimeConcat.given val axesTuple1 = Jax.Dynamic.global.tuple(Seq(axisIndex.value).toPythonProxy) val axesTuple2 = Jax.Dynamic.global.tuple(Seq(otherAxisIndex.value).toPythonProxy) @@ -222,16 +223,16 @@ object TensorOps: object LinearAlgebra: - extension [T <: Tuple : Labels](t: Tensor[T]) - def norm: Tensor0 = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) - def inv: Tensor[T] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) + extension [T <: Tuple : Labels, V : Value](t: Tensor[T, V]) + def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) + def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) def det[L1 : Label, L2 : Label](axis1: Axis[L1], axis2: Axis[L2])(using idx1: AxisIndex[T, L1], idx2: AxisIndex[T, L2], remover: RemoverAll[T, (L1, L2)], namesOf: Labels[remover.Out] - ): Tensor[remover.Out] = + ): Tensor[remover.Out, V] = // JAX det only works on the last two axes (-2, -1). We must move the user's selected axes to the end. 
val moved = Jax.jnp.moveaxis( t.jaxValue, @@ -245,21 +246,21 @@ object TensorOps: idx2: AxisIndex[T, L2], remover: RemoverAll[T, (L1, L2)], namesOf: Labels[remover.Out] - ): Tensor[remover.Out] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = idx1.value, axis2 = idx2.value)) + ): Tensor[remover.Out, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = idx1.value, axis2 = idx2.value)) def diagonal[L1 : Label, L2 : Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int=0)(using idx1: AxisIndex[T, L1], idx2: AxisIndex[T, L2], remover: RemoverAll[T, (L1, L2)], namesOf: Labels[remover.Out] - ): Tensor[remover.Out *: L1 *: EmptyTuple] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = idx1.value, axis2 = idx2.value)) + ): Tensor[remover.Out *: L1 *: EmptyTuple, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = idx1.value, axis2 = idx2.value)) - extension [L1 : Label, L2 : Label](t: Tensor2[L1, L2]) - def det: Tensor0 = Tensor0(Jax.jnp.linalg.det(t.jaxValue)) - def trace: Tensor0 = t.trace(0) - def trace(offset: Int): Tensor0 = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) - def diagonal: Tensor1[L1] = t.diagonal(0) - def diagonal(offset: Int): Tensor1[L1] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset)) + extension [L1 : Label, L2 : Label, V : Value](t: Tensor2[L1, L2, V]) + def det: Tensor0[V] = Tensor0(Jax.jnp.linalg.det(t.jaxValue)) + def trace: Tensor0[V] = t.trace(0) + def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) + def diagonal: Tensor1[L1, V] = t.diagonal(0) + def diagonal(offset: Int): Tensor1[L1, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset)) end LinearAlgebra @@ -350,32 +351,32 @@ object TensorOps: import Util.* object TensorWhere: - def where[T <: Tuple : Labels]( - condition: Tensor[T], - x: Tensor[T], - y: Tensor[T] - ): Tensor[T] = + def where[T <: Tuple : Labels, V : Value]( + condition: Tensor[T, Bool], + x: Tensor[T, V], + y: Tensor[T, V] + ): Tensor[T, V] = Tensor(Jax.jnp.where(condition.jaxValue, x.jaxValue, y.jaxValue)) export TensorWhere.where - def stack[L : Label, T <: Tuple : Labels]( - tensors: Seq[Tensor[T]], + def stack[L : Label, T <: Tuple : Labels, V : Value]( + tensors: Seq[Tensor[T, V]], newAxis: Axis[L], - ): Tensor[L *: T] = + ): Tensor[L *: T, V] = require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors") val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy val stackedJaxValue = Jax.jnp.stack(jaxValuesSeq, axis = 0) Tensor(stackedJaxValue) - def stack[NewL, L, T <: Tuple : Labels]( - tensors: Seq[Tensor[T]], + def stack[NewL, L, T <: Tuple : Labels, V : Value]( + tensors: Seq[Tensor[T, V]], newAxis: Axis[NewL], afterAxis: Axis[L], )(using newLabel: Label[NewL], axisIndex: AxisIndex[T, L], - ): Tensor[InsertAfter[T, L, NewL]] = + ): Tensor[InsertAfter[T, L, NewL], V] = require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors") val axisIdx = axisIndex.value + 1 // we are inserting after the given axis, so shift by 1 val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy @@ -386,19 +387,19 @@ object TensorOps: val names = newNames.toSeq Tensor(stackedJaxValue) - def concatenate[L : Label, T <: Tuple : Labels]( - tensors: Seq[Tensor[T]], + def concatenate[L : Label, T <: Tuple : Labels, V : Value]( + tensors: Seq[Tensor[T, V]], concatAxis: Axis[L], )( using axisIndex: AxisIndex[T, L], - ): Tensor[T] = + ): Tensor[T, V] = require(tensors.nonEmpty, "Cannot concatenate an empty sequence of tensors") val axisIdx = 
axisIndex.value val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy val concatenatedJaxValue = Jax.jnp.concatenate(jaxValuesSeq, axis = axisIdx) Tensor(concatenatedJaxValue) - extension [T <: Tuple : Labels](tensor: Tensor[T]) + extension [T <: Tuple : Labels, V : Value](tensor: Tensor[T, V]) private def calcPyIndices[Inputs <: Tuple]( inputs: Inputs, @@ -431,7 +432,7 @@ object TensorOps: def split[newL, splitL](newAxis: Axis[newL], splitAxis: Axis[splitL], interval: Int)(using newLabel: Label[newL], axisIndex: AxisIndex[T, splitL], - ): Tensor[InsertBefore[T, splitL, newL]] = + ): Tensor[InsertBefore[T, splitL, newL], V] = val splitIdx = axisIndex.value val names = summon[Labels[T]].names val newNames = names.take(splitIdx) ++ Seq(newLabel.name) ++ names.drop(splitIdx) @@ -451,9 +452,9 @@ object TensorOps: def chunk[splitL : Label](splitAxis: Axis[splitL], interval: Int)(using axisIndex: AxisIndex[T, splitL], - ): Seq[Tensor[T]] = + ): Seq[Tensor[T, V]] = val res = Jax.jnp.split(tensor.jaxValue, interval, axis = axisIndex.value).as[Seq[Jax.PyDynamic]] - res.map(x => Tensor[T](x)) + res.map(x => Tensor[T, V](x)) def tile = ??? def repeat = ??? @@ -465,7 +466,7 @@ object TensorOps: remover: RemoverAll.Aux[T, LabelsToRemove, R], axesIndices: AxisIndices[T, ExtractLabels[Inputs]], labels: Labels[R], - ): Tensor[R] = + ): Tensor[R, V] = val pyIndices = tensor.calcPyIndices(inputs, axesIndices) Tensor(tensor.jaxValue.bracketAccess(pyIndices)) @@ -476,7 +477,7 @@ object TensorOps: remover: RemoverAll.Aux[T, LabelsToRemove, R], axesIndices: AxisIndices[T, ExtractLabels[Tuple1[(Axis[L], I)]]], labels: Labels[R], - ): Tensor[R] = slice(Tuple1(axisWithSliceIndex)) + ): Tensor[R, V] = slice(Tuple1(axisWithSliceIndex)) def set[Inputs <: Tuple, LabelsToRemove <: Tuple, R <: Tuple]( inputs: Inputs @@ -485,7 +486,7 @@ object TensorOps: remover: RemoverAll.Aux[T, LabelsToRemove, R], axesIndices: AxisIndices[T, ExtractLabels[Inputs]], labels: Labels[R], - )(value: Tensor[R]): Tensor[T] = + )(value: Tensor[R, V]): Tensor[T, V] = val pyIndices = tensor.calcPyIndices(inputs, axesIndices) val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value.jaxValue) Tensor(result) @@ -497,9 +498,9 @@ object TensorOps: remover: RemoverAll.Aux[T, LabelsToRemove, R], axesIndices: AxisIndices[T, ExtractLabels[Tuple1[(Axis[L], I)]]], labels: Labels[R] - )(value: Tensor[R]): Tensor[T] = set(Tuple1(axisWithSliceIndex))(value) + )(value: Tensor[R, V]): Tensor[T, V] = set(Tuple1(axisWithSliceIndex))(value) - def rearrange[Axes <: Tuple](newOrder: Axes)(using Labels[UnwrapAxes[Axes]]): Tensor[UnwrapAxes[Axes]] = + def rearrange[Axes <: Tuple](newOrder: Axes)(using Labels[UnwrapAxes[Axes]]): Tensor[UnwrapAxes[Axes], V] = rearrange[Axes, EmptyTuple](newOrder, EmptyTuple) def rearrange[Axes <: Tuple, Dims <: Tuple]( @@ -508,7 +509,7 @@ object TensorOps: )(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[Dims], - ): Tensor[UnwrapAxes[Axes]] = + ): Tensor[UnwrapAxes[Axes], V] = def createEinopsPattern(fromPattern: String, toPattern: String): String = def cleanPattern(pattern: String): String = // to replace all a*b*c in pattern with (a b c), example: @@ -534,7 +535,7 @@ object TensorOps: def lift[O <: Tuple : Labels](newShape: Shape[O])( using ev: Subset[O, T] // Ensures T's axes are all present in O - ): Tensor[O] = + ): Tensor[O, V] = val t = tensor val currentNames = summon[Labels[T]].names @@ -562,9 +563,9 @@ object TensorOps: rename: (Axis[OldLabel], Axis[NewLabel]), )(using replacer: Replacer[T, 
OldLabel, NewLabel] - ): Tensor[replacer.Out] = Tensor(tensor.jaxValue) + ): Tensor[replacer.Out, V] = Tensor(tensor.jaxValue) - def retag[newT <: Tuple](using newLabels: Labels[newT]): Tensor[newT] = + def retag[newT <: Tuple](using newLabels: Labels[newT]): Tensor[newT, V] = Tensor(tensor.jaxValue)(using newLabels) def relabelAll[newT <: Tuple]( @@ -573,7 +574,7 @@ object TensorOps: newLabels: Labels[UnwrapAxes[newT]], @implicitNotFound("Cannot convert tensor of shape ${T} to shape ${newT} due to size mismatch.") evSameSize: Tuple.Size[newT] =:= Tuple.Size[T], - ): Tensor[UnwrapAxes[newT]] = Tensor[UnwrapAxes[newT]](tensor.jaxValue) + ): Tensor[UnwrapAxes[newT], V] = Tensor[UnwrapAxes[newT], V](tensor.jaxValue) def swap[L1 : Label, L2 : Label]( axis1: Axis[L1], @@ -581,7 +582,7 @@ object TensorOps: )(using axisIndex1: AxisIndex[T, L1], axisIndex2: AxisIndex[T, L2], - ): Tensor[Swap[T, L1, L2]] = + ): Tensor[Swap[T, L1, L2], V] = given Labels[Swap[T, L1, L2]] with def names = val originalNames = summon[Labels[T]].names @@ -594,17 +595,17 @@ object TensorOps: } Tensor(Jax.jnp.swapaxes(tensor.jaxValue, axisIndex1.value, axisIndex2.value)) - def ravel: Tensor1[JoinNames[T]] = + def ravel: Tensor1[JoinNames[T], V] = given Labels[Tuple1[JoinNames[T]]] with def names = List(summon[Labels[T]].names.mkString("*")) Tensor(Jax.jnp.ravel(tensor.jaxValue)) - def appendAxis[L : Label](axis: Axis[L])(using AxisAbsent[T, L]): Tensor[Tuple.Concat[T, Tuple1[L]]] = + def appendAxis[L : Label](axis: Axis[L])(using AxisAbsent[T, L]): Tensor[Tuple.Concat[T, Tuple1[L]], V] = import Labels.ForConcat.given val newShape = tensor.shape.dimensions :+ 1 Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) - def prependAxis[L : Label](axis: Axis[L])(using AxisAbsent[T, L]): Tensor[Tuple.Concat[Tuple1[L], T]] = + def prependAxis[L : Label](axis: Axis[L])(using AxisAbsent[T, L]): Tensor[Tuple.Concat[Tuple1[L], T], V] = import Labels.ForConcat.given val newShape = 1 +: tensor.shape.dimensions Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) @@ -613,7 +614,7 @@ object TensorOps: remover: RemoverAll.Aux[T, Tuple1[L], R], axisIndex: AxisIndex[T, L], labels: Labels[R], - ): Tensor[R] = + ): Tensor[R, V] = require( tensor.shape.dimensions(axisIndex.value) == 1, s"Cannot squeeze axis ${axis} of size ${tensor.shape.dimensions(axisIndex.value)}" @@ -630,23 +631,26 @@ object TensorOps: object ZipVmap: - type TensorsOf[Shapes <: Tuple] <: Tuple = Shapes match + type TensorsOf[Shapes <: Tuple, V] <: Tuple = Shapes match case EmptyTuple => EmptyTuple case head *: tail => head match - case Tuple => Tensor[head] *: TensorsOf[tail] + case Tuple => Tensor[head, V] *: TensorsOf[tail, V] type ExtractShape[T] = T match - case Tensor[s] => s + case Tensor[s, v] => s + + type ExtractValue[T] = T match + case Tensor[s, v] => v type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape] - trait Zipper[Shapes <: Tuple, L, SlicedShapes <: Tuple]: - def dimSize(tensors: TensorsOf[Shapes], axis: Axis[L]): Int - def sliceAll(tensors: TensorsOf[Shapes], axis: Axis[L], idx: Int): TensorsOf[SlicedShapes] + trait Zipper[Shapes <: Tuple, L, SlicedShapes <: Tuple, V]: + def dimSize(tensors: TensorsOf[Shapes, V], axis: Axis[L]): Int + def sliceAll(tensors: TensorsOf[Shapes, V], axis: Axis[L], idx: Int): TensorsOf[SlicedShapes, V] object Zipper: - given empty[L]: Zipper[EmptyTuple, L, EmptyTuple] = new Zipper[EmptyTuple, L, EmptyTuple]: + given empty[L, V]: Zipper[EmptyTuple, L, EmptyTuple, V] = new Zipper[EmptyTuple, L, 
EmptyTuple, V]: def dimSize(t: EmptyTuple, axis: Axis[L]) = 0 def sliceAll(t: EmptyTuple, axis: Axis[L], idx: Int) = EmptyTuple @@ -655,34 +659,36 @@ object TensorOps: TailShapes <: Tuple, L, HeadSliced <: Tuple, - TailSliced <: Tuple + TailSliced <: Tuple, + V ](using remover: RemoverAll.Aux[HeadShape, L *: EmptyTuple, HeadSliced], axisIndex: AxisIndex[HeadShape, L], - tailZipper: Zipper[TailShapes, L, TailSliced], + tailZipper: Zipper[TailShapes, L, TailSliced, V], labels1: Labels[HeadShape], labels2: Labels[HeadSliced], - ): Zipper[HeadShape *: TailShapes, L, HeadSliced *: TailSliced] = - new Zipper[HeadShape *: TailShapes, L, HeadSliced *: TailSliced]: + value: Value[V], + ): Zipper[HeadShape *: TailShapes, L, HeadSliced *: TailSliced, V] = + new Zipper[HeadShape *: TailShapes, L, HeadSliced *: TailSliced, V]: - def dimSize(tensors: TensorsOf[HeadShape *: TailShapes], axis: Axis[L]): Int = - val head = tensors.asInstanceOf[Tensor[HeadShape] *: Tuple].head + def dimSize(tensors: TensorsOf[HeadShape *: TailShapes, V], axis: Axis[L]): Int = + val head = tensors.asInstanceOf[Tensor[HeadShape, V] *: Tuple].head head.shape.dimensions(axisIndex.value) - def sliceAll(tensors: TensorsOf[HeadShape *: TailShapes], axis: Axis[L], idx: Int): TensorsOf[HeadSliced *: TailSliced] = - val tuple = tensors.asInstanceOf[Tensor[HeadShape] *: TensorsOf[TailShapes]] - val slicedHead = tuple.head.slice(axis -> idx) + def sliceAll(tensors: TensorsOf[HeadShape *: TailShapes, V], axis: Axis[L], idx: Int): TensorsOf[HeadSliced *: TailSliced, V] = + val tuple = tensors.asInstanceOf[Tensor[HeadShape, V] *: TensorsOf[TailShapes, V]] + val slicedHead = tuple.head.slice(axis -> idx)(using labels1, value) val slicedTail = tailZipper.sliceAll(tuple.tail, axis, idx) - (slicedHead *: slicedTail).asInstanceOf[TensorsOf[HeadSliced *: TailSliced]] + (slicedHead *: slicedTail).asInstanceOf[TensorsOf[HeadSliced *: TailSliced, V]] - case class ZipResult[L : Label, Shapes <: Tuple, SlicedShapes <: Tuple]( + case class ZipResult[L : Label, Shapes <: Tuple, SlicedShapes <: Tuple, V : Value]( axis: Axis[L], - tensors: TensorsOf[Shapes] - )(using zipper: Zipper[Shapes, L, SlicedShapes]): + tensors: TensorsOf[Shapes, V] + )(using zipper: Zipper[Shapes, L, SlicedShapes, V]): def vmap[OutShape <: Tuple : Labels]( - f: TensorsOf[SlicedShapes] => Tensor[OutShape] - ): Tensor[L *: OutShape] = + f: TensorsOf[SlicedShapes, V] => Tensor[OutShape, V] + ): Tensor[L *: OutShape, V] = val size = zipper.dimSize(tensors, axis) @@ -696,18 +702,18 @@ object TensorOps: Structural.stack(results, axis) - def zip[L : Label, Inputs <: Tuple, Sliced <: Tuple]( + def zip[L : Label, Inputs <: Tuple, Sliced <: Tuple, V : Value]( axis: Axis[L] )( tensors: Inputs )(using - zipper: Zipper[ShapesOf[Inputs], L, Sliced] - ): ZipResult[L, ShapesOf[Inputs], Sliced] = - ZipResult(axis, tensors.asInstanceOf[TensorsOf[ShapesOf[Inputs]]]) + zipper: Zipper[ShapesOf[Inputs], L, Sliced, V] + ): ZipResult[L, ShapesOf[Inputs], Sliced, V] = + ZipResult(axis, tensors.asInstanceOf[TensorsOf[ShapesOf[Inputs], V]]) def zipvmap[ L : Label, - Inputs <: Tuple, + Inputs <: NonEmptyTuple, OutShape <: Tuple : Labels, Sliced <: Tuple, ]( @@ -715,15 +721,15 @@ object TensorOps: )( tensors: Inputs )(using - zipper: Zipper[ShapesOf[Inputs], L, Sliced] + zipper: Zipper[ShapesOf[Inputs], L, Sliced, ExtractValue[Tuple.Head[Inputs]]] )( - f: TensorsOf[Sliced] => Tensor[OutShape] - ): Tensor[L *: OutShape] = + f: TensorsOf[Sliced, ExtractValue[Tuple.Head[Inputs]]] => Tensor[OutShape, 
ExtractValue[Tuple.Head[Inputs]]] + )(using Value[ExtractValue[Tuple.Head[Inputs]]]): Tensor[L *: OutShape, ExtractValue[Tuple.Head[Inputs]]] = zip(axis)(tensors).vmap(f) export ZipVmap.zipvmap - extension [T <: Tuple : Labels](t: Tensor[T]) + extension [T <: Tuple : Labels, V : Value](t: Tensor[T, V]) def vmap[VmapAxis : Label, OuterShape <: Tuple : Labels, R <: Tuple]( axis: Axis[VmapAxis] @@ -731,12 +737,12 @@ object TensorOps: remover: Remover.Aux[T, VmapAxis, R], vmapAxisIndex: AxisIndex[T, VmapAxis], )( - f: Tensor[R] => Tensor[OuterShape] + f: Tensor[R, V] => Tensor[OuterShape, V] )(using labels: Labels[R] - ): Tensor[VmapAxis *: OuterShape] = + ): Tensor[VmapAxis *: OuterShape, V] = val fpy = (jxpr: Jax.PyDynamic) => - val innerTensor = Tensor[R](jxpr) + val innerTensor = Tensor[R, V](jxpr) println(("A", innerTensor.shape)) val result = f(innerTensor) result.jaxValue @@ -749,10 +755,10 @@ object TensorOps: axisIndex: AxisIndex[T, L], replacer: Replacer[T, L, OutAxis], )( - f: Tensor[Tuple1[L]] => Tensor[Tuple1[OutAxis]] - ): Tensor[replacer.Out] = + f: Tensor[Tuple1[L], V] => Tensor[Tuple1[OutAxis], V] + ): Tensor[replacer.Out, V] = val fpy = (jxpr: Jax.PyDynamic) => - val inputTensor = Tensor[Tuple1[L]](jxpr) + val inputTensor = Tensor[Tuple1[L], V](jxpr) val result = f(inputTensor) result.jaxValue @@ -765,14 +771,14 @@ object TensorOps: def vreduce[L : Label, R <: Tuple]( axis: Axis[L] )( - f: Tensor[Tuple1[L]] => Tensor0 + f: Tensor[Tuple1[L], V] => Tensor0[V] )(using axisIndex: AxisIndex[T, L], remover: Remover.Aux[T, L, R], labels: Labels[R] - ): Tensor[remover.Out] = + ): Tensor[remover.Out, V] = val fpy = (jxpr: Jax.PyDynamic) => - val inputTensor = Tensor[Tuple1[L]](jxpr) + val inputTensor = Tensor[Tuple1[L], V](jxpr) val result = f(inputTensor) result.jaxValue @@ -795,45 +801,45 @@ object TensorOps: // Common specialized operation names // ----------------------------------------------------------- object ScalarOps: - extension (t: Tensor0) + extension [V : Value](t: Tensor0[V]) def toInt: Int = t.jaxValue.item().as[Int] def toFloat: Float = t.jaxValue.item().as[Float] def toBool: Boolean = t.jaxValue.item().as[Boolean] @targetName("tensor0Pow") - def pow(exponent: Tensor0): Tensor0 = Tensor0(Jax.jnp.pow(t.jaxValue, exponent.jaxValue)) + def pow(exponent: Tensor0[V]): Tensor0[V] = Tensor0(Jax.jnp.pow(t.jaxValue, exponent.jaxValue)) object VectorOps: - extension [L : Label](t: Tensor1[L]) - def dot(other: Tensor1[L]): Tensor0 = t.innerDot(other) - def innerDot(other: Tensor1[L]): Tensor0 = t.contract(Axis[L])(other) - def outerDot[OtherLabel : Label](other: Tensor1[OtherLabel]): Tensor2[L, OtherLabel] = + extension [L : Label, V : Value](t: Tensor1[L, V]) + def dot(other: Tensor1[L, V]): Tensor0[V] = t.innerDot(other) + def innerDot(other: Tensor1[L, V]): Tensor0[V] = t.contract(Axis[L])(other) + def outerDot[OtherLabel : Label](other: Tensor1[OtherLabel, V]): Tensor2[L, OtherLabel, V] = val result = t.outerProduct(other) result - def relabelTo[NewL : Label](newAxis: Axis[NewL]): Tensor1[NewL] = Tensor[Tuple1[NewL]](t.jaxValue) + def relabelTo[NewL : Label](newAxis: Axis[NewL]): Tensor1[NewL, V] = Tensor[Tuple1[NewL], V](t.jaxValue) object MatrixOps: - extension [L1 : Label, L2 : Label](t: Tensor2[L1, L2]) - def transpose: Tensor2[L2, L1] = t.rearrange((Axis[L2], Axis[L1])) + extension [L1 : Label, L2 : Label, V : Value](t: Tensor2[L1, L2, V]) + def transpose: Tensor2[L2, L1, V] = t.rearrange((Axis[L2], Axis[L1])) @targetName("tensor2MatmulTensor2") - def matmul[L3 : 
Label](other: Tensor2[L2, L3])(using + def matmul[L3 : Label](other: Tensor2[L2, L3, V])(using remover: Remover.Aux[(L1, L2), L2, Tuple1[L1]], otherRemover: Remover.Aux[(L2, L3), L2, Tuple1[L3]], idx1: AxisIndex[(L1, L2), L2], idx2: AxisIndex[(L2, L3), L2], - ): Tensor2[L1, L3] = + ): Tensor2[L1, L3, V] = val result = t.contract(Axis[L2])(other) result @targetName("tensor2MatmulTensor1") - def matmul(other: Tensor1[L2])(using + def matmul(other: Tensor1[L2, V])(using remover: Remover.Aux[(L1, L2), L2, Tuple1[L1]], otherRemover: Remover.Aux[Tuple1[L2], L2, EmptyTuple], idx1: AxisIndex[(L1, L2), L2], idx2: AxisIndex[Tuple1[L2], L2], - ): Tensor[Tuple1[L1]] = + ): Tensor[Tuple1[L1], V] = val result = t.contract(Axis[L2])(other) result diff --git a/examples/src/main/scala/basic/Playground.scala b/examples/src/main/scala/basic/Playground.scala index 227c76f..0b0bd40 100644 --- a/examples/src/main/scala/basic/Playground.scala +++ b/examples/src/main/scala/basic/Playground.scala @@ -3,8 +3,24 @@ package examples.basic import shapeful.* import scala.util.NotGiven +trait Y +object Y extends Value.As[Y, Float32] @main def playground(): Unit = + + + val t = Tensor.of[Y].zeros(Shape( + Axis["Batch"] -> 4, + Axis["Features"] -> 8, + )) + + val t2 = Tensor.of[Float32].apply(Shape( + Axis["Batch"] -> 4, + Axis["Features"] -> 8, + ), Array.fill(32)(1.0f)) + + val t3 = t2 + t2 + println("TensorV2 Playground") { println("MatMul tests") @@ -12,7 +28,7 @@ import scala.util.NotGiven 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ).map(_.toFloat) - val X = Tensor2( + val X = Tensor.of[Float32]( values = values, shape = Shape( Axis["Samples"] -> 10, @@ -31,7 +47,7 @@ import scala.util.NotGiven 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ).map(_.toFloat) - val X = Tensor2( + val X = Tensor.of[Float32]( values = values, shape = Shape( Axis["Samples"] -> 10, @@ -50,21 +66,21 @@ import scala.util.NotGiven } { println("DType and Device tests") - val t = Tensor.zeros(Shape( + val t = Tensor.of[Float32].zeros(Shape( Axis["Batch"] -> 1024, Axis["Features"] -> 512, )) println(t.shape) println(t.dtype) - println(t.asType(DType.Int32).dtype) + println(t.asType[Int32].dtype) println(t.device) println(t.toDevice(Device.CPU).device) } { - val x = Tensor.zeros(Shape( + val x = Tensor.of[Float32].zeros(Shape( Axis["Features"] -> 2, )) - val A = Tensor.zeros(Shape( + val A = Tensor.of[Float32].zeros(Shape( Axis["Samples"] -> 50, Axis["Features"] -> 2, )) @@ -86,7 +102,7 @@ import scala.util.NotGiven type Width = "width" type Height = "height" type Channel = "channel" - val X = Tensor.zeros(Shape( + val X = Tensor.of[Float32].zeros(Shape( Axis[Batch] -> 32, Axis[Frame] -> 64, Axis[Width] -> 256, @@ -111,7 +127,7 @@ import scala.util.NotGiven trait Width derives Label trait Height derives Label trait Channel derives Label - val X = Tensor.zeros(Shape( + val X = Tensor.of[Float32].zeros(Shape( Axis[Batch] -> 32, Axis[Frame] -> 64, Axis[Width] -> 256, @@ -130,60 +146,60 @@ import scala.util.NotGiven { println("Contraction with overlapping axes") import scala.util.NotGiven - def f[L1: Label, L2: Label, L3: Label]( - x: Tensor[(L1, L2)], - y: Tensor[(L2, L3)] - ): Tensor[(L1, L3, L2)] = + def f[L1: Label, L2: Label, L3: Label, V:Value]( + x: Tensor[(L1, L2), V], + y: Tensor[(L2, L3), V] + ): Tensor[(L1, L3, L2), V] = x.vmap(Axis[L1]){ xi => y.vmap(Axis[L3]){ yi => xi + yi } } - val z = f(Tensor.zeros(Shape( + val z = f(Tensor.of[Float32].zeros(Shape( Axis["A"] -> 2, Axis["B"] -> 3, - )), 
Tensor.zeros(Shape( + )), Tensor.of[Float32].zeros(Shape( Axis["B"] -> 3, Axis["C"] -> 4, ))) println(z.shape) } { - def f(t1: Tensor[("A", "C")], t2: Tensor[Tuple1["C"]]): Tensor[Tuple1["A"]] = + def f(t1: Tensor[("A", "C"), Float32], t2: Tensor[Tuple1["C"], Float32]): Tensor[Tuple1["A"], Float32] = t1.matmul(t2) - val t1 = Tensor.ones(Shape( + val t1 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["C"] -> 2, )) - val t2 = Tensor.ones(Shape( + val t2 = Tensor.of[Float32].ones(Shape( Axis["C"] -> 2, )) println(f(t1, t2)) println("vmap 2") import scala.util.NotGiven - val x1 = Tensor.ones(Shape( + val x1 = Tensor.of[Float32].ones(Shape( Axis["B"] -> 1, Axis["A"] -> 2, Axis["C"] -> 2, )) - val x2 = Tensor.ones(Shape( + val x2 = Tensor.of[Float32].ones(Shape( Axis["B"] -> 1, Axis["C"] -> 2, )) } { - def f[L1: Label, L2: Label, L3: Label](x: Tensor[(L1, L2)], y: Tensor[(L2, L3)]) = + def f[L1: Label, L2: Label, L3: Label, V: Value](x: Tensor[(L1, L2), V], y: Tensor[(L2, L3), V]) = x.vmap(Axis[L1]){ xi => y.vmap(Axis[L3]){ yi => xi + yi } } println(f( - Tensor.zeros(Shape( + Tensor.of[Float32].zeros(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )), - Tensor.zeros(Shape( + Tensor.of[Float32].zeros(Shape( Axis["B"] -> 3, Axis["C"] -> 4, )) @@ -192,7 +208,7 @@ import scala.util.NotGiven { println("Ravel") - val res = Tensor.ones(Shape( + val res = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, @@ -201,7 +217,7 @@ import scala.util.NotGiven } { println("swapaxes") - val res = Tensor.ones(Shape( + val res = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, @@ -210,13 +226,13 @@ import scala.util.NotGiven } { println("appendAxis / prependAxis") - val res = Tensor.ones(Shape( + val res = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, )).appendAxis(Axis["D"]) println(res.shape) - val res2 = Tensor.ones(Shape( + val res2 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, @@ -225,7 +241,7 @@ import scala.util.NotGiven } { println("squeeze") - val res = Tensor.ones(Shape( + val res = Tensor.of[Float32].ones(Shape( Axis["A"] -> 1, Axis["B"] -> 3, Axis["C"] -> 1, @@ -236,21 +252,21 @@ import scala.util.NotGiven } { println("Slice") - val res = Tensor.ones(Shape( + val res = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )).slice( Axis["B"] -> 2 ) println(res.shape) - val res2 = Tensor.ones(Shape( + val res2 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )).slice( Axis["B"] -> (0 to 1) ) println(res2.shape) - val res3 = Tensor.ones(Shape( + val res3 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, @@ -269,19 +285,19 @@ import scala.util.NotGiven type Sector = "Sector" type Risk = "Risk" - val x = Tensor.ones(Shape( + val x = Tensor.of[Float32].ones(Shape( Axis[Batch] -> 6, Axis[Asset] -> 3, Axis[Region] -> 5, )) - val y = Tensor.ones(Shape( + val y = Tensor.of[Float32].ones(Shape( Axis[Region] -> 5, Axis[Batch] -> 6, Axis[Sector] -> 4, )) - val z = Tensor.ones(Shape( + val z = Tensor.of[Float32].ones(Shape( Axis[Sector] -> 4, Axis[Risk] -> 5, Axis[Batch] -> 6, @@ -306,9 +322,9 @@ import scala.util.NotGiven type Sector = "Sector" type Risk = "Risk" - val x = Tensor.ones(Shape(Axis[Batch] -> 6, Axis[Asset] -> 3, Axis[Region] -> 5)) - val y = Tensor.ones(Shape(Axis[Region] -> 5, Axis[Batch] -> 6, Axis[Sector] -> 4)) - val z = Tensor.ones(Shape(Axis[Sector] -> 4, Axis[Risk] -> 5, Axis[Batch] -> 6)) + val x = 
Tensor.of[Float32].ones(Shape(Axis[Batch] -> 6, Axis[Asset] -> 3, Axis[Region] -> 5)) + val y = Tensor.of[Float32].ones(Shape(Axis[Region] -> 5, Axis[Batch] -> 6, Axis[Sector] -> 4)) + val z = Tensor.of[Float32].ones(Shape(Axis[Sector] -> 4, Axis[Risk] -> 5, Axis[Batch] -> 6)) val res = zipvmap(Axis[Batch])((x, y, z)) { case (xi, yi, zi) => xi.sum + yi.sum + zi.sum @@ -317,24 +333,24 @@ import scala.util.NotGiven } { println("TensorWhere tests") - val x = Tensor.ones(Shape( + val x = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )) - val y = Tensor.zeros(Shape( + val y = Tensor.of[Float32].zeros(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )) - val condition = Tensor.zeros(Shape( + val condition = Tensor.of[Float32].zeros(Shape( Axis["A"] -> 2, Axis["B"] -> 3, - )).asType(DType.Bool) + )).asType[Bool](DType.Bool) val res = where(condition, x, y) println(res.shape) } { println("Diag") - val x = Tensor.ones(Shape( + val x = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )) @@ -343,19 +359,19 @@ import scala.util.NotGiven } { println("Set") - val x = Tensor.ones(Shape( + val x = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )).set(( Axis["A"] -> 1, Axis["B"] -> 2, - ))(Tensor0(42)) + ))(Tensor0.of[Float32](42)) println(x) - val v = Tensor1( + val v = Tensor1.of[Float32]( Axis["B"], - Array(100, 101, 102), + Array(100, 101, 102).map(_.toFloat), ) - val x2 = Tensor.ones(Shape( + val x2 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )).set( @@ -365,7 +381,7 @@ import scala.util.NotGiven } { // attention mechanism example - def softmax[L: Label](tensor: Tensor1[L]): Tensor1[L] = + def softmax[L: Label](tensor: Tensor1[L, Float32]): Tensor1[L, Float32] = val expTensor = tensor.exp val sumExp = expTensor.sum expTensor.vmap(Axis[L]) { _ / sumExp } @@ -376,29 +392,29 @@ import scala.util.NotGiven trait Context derives Label case class Attention( - wk: Tensor2[Value, Key], - wq: Tensor2[Value, Query], - wv: Tensor2[Value, Prime[Value]], + wk: Tensor2[Value, Key, Float32], + wq: Tensor2[Value, Query, Float32], + wv: Tensor2[Value, Prime[Value], Float32], ): private trait AttnWeights derives Label - def apply(x: Tensor2[Context, Value]): Tensor2[Context, Value] = + def apply(x: Tensor2[Context, Value, Float32]): Tensor2[Context, Value, Float32] = val k = x.contract(Axis[Value])(wk) val q = x.contract(Axis[Value])(wq) val v = x.contract(Axis[Value])(wv) - val dk = Tensor0(Math.sqrt(k.shape(Axis[Key])).toFloat) + val dk = Tensor0.of[Float32](Math.sqrt(k.shape(Axis[Key])).toFloat) val attnWeightsPrime = q.contract(Axis[Query ~ Key])(k) - .vmap(Axis[Context])(x => softmax(x).relabelTo(Axis[AttnWeights])) + .vmap(Axis[Context])(attnRow => softmax(attnRow).relabelTo(Axis[AttnWeights])) val resPrime = attnWeightsPrime.contract(Axis[AttnWeights ~ Context])(v) resPrime.relabel(Axis[Prime[Value]] -> Axis[Value]) trait Batch derives Label - val x = Tensor.ones(Shape(Axis[Batch] -> 32, Axis[Context] -> 128, Axis[Value] -> 64)) + val x = Tensor.of[Float32].ones(Shape(Axis[Batch] -> 32, Axis[Context] -> 128, Axis[Value] -> 64)) val attention = Attention( - Tensor.ones(Shape(Axis[Value] -> 64, Axis[Key] -> 64)), - Tensor.ones(Shape(Axis[Value] -> 64, Axis[Query] -> 64)), - Tensor.ones(Shape(Axis[Value] -> 64, Axis[Prime[Value]] -> 64)), + Tensor.of[Float32].ones(Shape(Axis[Value] -> 64, Axis[Key] -> 64)), + Tensor.of[Float32].ones(Shape(Axis[Value] -> 64, Axis[Query] -> 64)), + Tensor.of[Float32].ones(Shape(Axis[Value] -> 64, Axis[Prime[Value]] -> 64)), ) 
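    // Note: attention maps Tensor2[Context, Value, Float32] => Tensor2[Context, Value, Float32];
    // vmap over Axis[Batch] applies it to every batch element, so newX keeps the
    // (Batch -> 32, Context -> 128, Value -> 64) shape of x.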
val newX = x.vmap(Axis[Batch])(attention(_)) println(newX.shape) @@ -406,21 +422,21 @@ import scala.util.NotGiven { println("Attention") // multi-head attention mechanism example - def softmax[L: Label](tensor: Tensor1[L]): Tensor1[L] = + def softmax[L: Label](tensor: Tensor1[L, Float32]): Tensor1[L, Float32] = val expTensor = tensor.exp val sumExp = expTensor.sum expTensor.vmap(Axis[L]) { _ / sumExp } - val X = Tensor.ones( + val X = Tensor.of[Float32].ones( Shape(Axis["Batch"] -> 32, Axis["Sequence"] -> 128, Axis["Value"] -> 64) ) - val WK = Tensor.ones( + val WK = Tensor.of[Float32].ones( Shape(Axis["Value"] -> 64, Axis["Key"] -> 8, Axis["Heads"] -> 8) ) - val WQ = Tensor.ones( + val WQ = Tensor.of[Float32].ones( Shape(Axis["Value"] -> 64, Axis["Query"] -> 8, Axis["Heads"] -> 8) ) - val WV = Tensor.ones( + val WV = Tensor.of[Float32].ones( Shape(Axis["Value"] -> 64, Axis["NewValue"] -> 8, Axis["Heads"] -> 8) ) val Xnew = X.vmap(Axis["Batch"]) { Xi => @@ -428,7 +444,7 @@ import scala.util.NotGiven val Q = Xi.contract(Axis["Value"])(WQ) val V = Xi.contract(Axis["Value"])(WV) val res = zipvmap(Axis["Heads"])(Q, K, V) { (Qi, Ki, Vi) => - val dk = Tensor0(Math.sqrt(Ki.shape(Axis["Key"])).toFloat) + val dk = Tensor0.of[Float32](Math.sqrt(Ki.shape(Axis["Key"])).toFloat) val AttnWeights = (Qi.contract(Axis["Query" ~ "Key"])(Ki) :/ dk) .relabelAll((Axis["NewSequence"], Axis["Weights"])) .vmap(Axis["NewSequence"])(softmax) @@ -452,9 +468,9 @@ import scala.util.NotGiven type AxisAB2 = Axis[A | B] type exists = Axis[A & B] - val ab = Tensor.ones(Shape(Axis[A] -> 2, Axis[B] -> 2)) - val ba = Tensor.ones(Shape(Axis[B] -> 2, Axis[A] -> 2)) - val cd = Tensor.ones(Shape(Axis[C] -> 2, Axis[D] -> 2)) + val ab = Tensor.of[Float32].ones(Shape(Axis[A] -> 2, Axis[B] -> 2)) + val ba = Tensor.of[Float32].ones(Shape(Axis[B] -> 2, Axis[A] -> 2)) + val cd = Tensor.of[Float32].ones(Shape(Axis[C] -> 2, Axis[D] -> 2)) val res2 = ab.slice(Axis[A | C] -> 1) @@ -482,23 +498,23 @@ import scala.util.NotGiven xxx, Labels.namesOfEmpty ) given Labels[(A | B) *: EmptyTuple] = yyy - val aorb = Tensor.ones(Shape(Axis[A | B] -> 2)(using xxx)) + val aorb = Tensor.of[Float32].ones(Shape(Axis[A | B] -> 2)(using xxx)) val lala = summon[Label[A]] // val r3 = aorb.slice(Axis[A] -> 1) val r3 = aorb.slice(Axis[A | B] -> 1) println(r3.shape) } { - val t1 = Tensor.ones(Shape( + val t1 = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, )) val t2 = t1.appendAxis(Axis["D"]) // val t3 = t1.appendAxis(Axis["A"]) // should not compile - def f[T <: Tuple : Labels](t: Tensor[T]) = + def f[T <: Tuple : Labels, V: Value](t: Tensor[T, V]) = t.appendAxis(Axis["D"]) - def f2[T <: Tuple : Labels](t: Tensor[T]) = + def f2[T <: Tuple : Labels, V: Value](t: Tensor[T, V]) = t.appendAxis(Axis["A"]) val t3 = f(t1) println(t3.shape) From 123ffe7a6f463303dd40ad3f0a16390f1437a1d8 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 09:31:27 +0100 Subject: [PATCH 03/10] better asType syntax --- core/src/main/scala/shapeful/tensor/Tensor.scala | 4 ++-- examples/src/main/scala/basic/Playground.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/shapeful/tensor/Tensor.scala b/core/src/main/scala/shapeful/tensor/Tensor.scala index aacd706..9b2c84f 100644 --- a/core/src/main/scala/shapeful/tensor/Tensor.scala +++ b/core/src/main/scala/shapeful/tensor/Tensor.scala @@ -32,8 +32,8 @@ class Tensor[T <: Tuple : Labels, V : Value] private[tensor]( d => 
Jax.device_get(jaxValue).equals(d.jaxDevice) ).getOrElse(Device.Other(Jax.device_get(jaxValue).name.as[String])) - def asType[V2 : Value](newDType: DType): Tensor[T, V2] = - Tensor[T, V2](jaxValue = Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(newDType))) + def asType[V2 : Value]: Tensor[T, V2] = + Tensor[T, V2](jaxValue = Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(summon[Value[V2]].dtype))) def toDevice(newDevice: Device): Tensor[T, V] = Tensor[T, V](jaxValue = Jax.device_put(jaxValue, newDevice.jaxDevice)) diff --git a/examples/src/main/scala/basic/Playground.scala b/examples/src/main/scala/basic/Playground.scala index 0b0bd40..d6b7a0e 100644 --- a/examples/src/main/scala/basic/Playground.scala +++ b/examples/src/main/scala/basic/Playground.scala @@ -344,7 +344,7 @@ object Y extends Value.As[Y, Float32] val condition = Tensor.of[Float32].zeros(Shape( Axis["A"] -> 2, Axis["B"] -> 3, - )).asType[Bool](DType.Bool) + )).asType[Bool] val res = where(condition, x, y) println(res.shape) } From 6d56fa87f1a294908ec236088b10c8036f132c72 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 09:50:45 +0100 Subject: [PATCH 04/10] putting autodiff back in, supporting Value type parameter --- .../scala/shapeful/autodiff/Autodiff.scala | 14 ++--- .../main/scala/shapeful/autodiff/PyTree.scala | 16 ++--- .../scala/shapeful/autodiff/TensorTree.scala | 22 +++---- core/src/main/scala/shapeful/jax/Jit.scala | 10 ++-- core/src/main/scala/shapeful/package.scala | 2 +- examples/src/main/scala/api/AutodiffAPI.scala | 60 +++++++++---------- 6 files changed, 63 insertions(+), 61 deletions(-) diff --git a/core/src/main/scala/shapeful/autodiff/Autodiff.scala b/core/src/main/scala/shapeful/autodiff/Autodiff.scala index c4fd9dd..6d36af6 100644 --- a/core/src/main/scala/shapeful/autodiff/Autodiff.scala +++ b/core/src/main/scala/shapeful/autodiff/Autodiff.scala @@ -1,6 +1,6 @@ package shapeful.autodiff -import shapeful.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape, AxisIndices} +import shapeful.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape, AxisIndices, Value} import shapeful.jax.Jax import me.shadaj.scalapy.py @@ -9,17 +9,17 @@ object Autodiff: type Gradient[In, Out] = Out match case EmptyTuple => EmptyTuple case h *: t => Gradient[In, h] *: Gradient[In, t] - case Tensor[outS] => GradientTensorVsInput[In, outS] + case Tensor[outS, v] => GradientTensorVsInput[In, outS, v] case _ => EmptyTuple - type GradientTensorVsInput[In, OutShape <: Tuple] = In match + type GradientTensorVsInput[In, OutShape <: Tuple, V] = In match case EmptyTuple => EmptyTuple - case h *: t => GradientTensorVsInput[h, OutShape] *: GradientTensorVsInput[t, OutShape] - case Tensor[inS] => Tensor[Tuple.Concat[OutShape, inS]] + case h *: t => GradientTensorVsInput[h, OutShape, V] *: GradientTensorVsInput[t, OutShape, V] + case Tensor[inS, v2] => Tensor[Tuple.Concat[OutShape, inS], V] - def grad[Input](f: Input => Tensor0)(using + def grad[Input, V : Value](f: Input => Tensor0[V])(using inTree: ToPyTree[Input], - outTree: ToPyTree[Tensor0], + outTree: ToPyTree[Tensor0[V]], ): Input => Input = val fpy = (jxpr: py.Dynamic) => diff --git a/core/src/main/scala/shapeful/autodiff/PyTree.scala b/core/src/main/scala/shapeful/autodiff/PyTree.scala index 9c6d52c..2f30047 100644 --- a/core/src/main/scala/shapeful/autodiff/PyTree.scala +++ b/core/src/main/scala/shapeful/autodiff/PyTree.scala @@ -1,6 +1,6 @@ package shapeful.autodiff -import shapeful.tensor.{Tensor, Shape} +import shapeful.tensor.{Tensor, Shape, Value} import shapeful.jax.Jax import 
me.shadaj.scalapy.py @@ -18,9 +18,9 @@ object ToPyTree: def apply[P](using pt: ToPyTree[P]): ToPyTree[P] = pt // Keep the tensor instance - given [T <: Tuple : Labels]: ToPyTree[Tensor[T]] with - def toPyTree(t: Tensor[T]): Jax.PyAny = t.jaxValue - def fromPyTree(p: Jax.PyAny): Tensor[T] = Tensor.fromPy(p.as[Jax.PyDynamic]) + given [T <: Tuple : Labels, V : Value]: ToPyTree[Tensor[T, V]] with + def toPyTree(t: Tensor[T, V]): Jax.PyAny = t.jaxValue + def fromPyTree(p: Jax.PyAny): Tensor[T, V] = Tensor.fromPy(p.as[Jax.PyDynamic]) // Tuple instances - these should have lower priority than specific case classes given tupleInstance[A, B](using ta: ToPyTree[A], tb: ToPyTree[B]): ToPyTree[(A, B)] with @@ -57,7 +57,9 @@ object ToPyTree: inline def reconstructField[T](pyElem: py.Dynamic): T = inline erasedValue[T] match - case _: Tensor[?] => Tensor.fromPy(pyElem.as[Jax.PyDynamic]).asInstanceOf[T] + case _: Tensor[?, ?] => + // For tensors, delegate to the ToPyTree instance which has the proper type info + compiletime.summonInline[ToPyTree[T]].fromPyTree(pyElem) case _: String => pyElem.as[String].asInstanceOf[T] case _: Int => @@ -84,8 +86,8 @@ object ToPyTree: inline def convertSingleField[T](elem: T): Jax.PyAny = inline erasedValue[T] match - case _: Tensor[?] => - elem.asInstanceOf[Tensor[?]].jaxValue + case _: Tensor[?, ?] => + elem.asInstanceOf[Tensor[?, ?]].jaxValue case _: String => py.Dynamic.global.str(elem.asInstanceOf[String]) case _: Int => diff --git a/core/src/main/scala/shapeful/autodiff/TensorTree.scala b/core/src/main/scala/shapeful/autodiff/TensorTree.scala index 7459f50..997b0ad 100644 --- a/core/src/main/scala/shapeful/autodiff/TensorTree.scala +++ b/core/src/main/scala/shapeful/autodiff/TensorTree.scala @@ -7,20 +7,20 @@ import scala.compiletime.* // TODO hot fix with retag and context parameter... maybe this can be improved? 
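/** TensorTree maps a polymorphic, dtype-aware function over every tensor leaf of a
  * parameter product P. With Value threaded through, the callback receives both
  * Labels[T] and Value[V], so a dtype-generic update such as
  *   [T <: Tuple, V] => (n: Labels[T], v: Value[V]) ?=> (g: Tensor[T, V], p: Tensor[T, V]) =>
  *     p - g.scale(Tensor0.of[V].apply(lr))
  * (the GradientDescent step later in this series) type-checks for any element type.
  */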
trait TensorTree[P]: - def map(p: P, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): P - def zipMap(p1: P, p2: P, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): P + def map(p: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): P + def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P object TensorTree extends TensorTreeLowPriority: def apply[P](using pt: TensorTree[P]): TensorTree[P] = pt - given [Q <: Tuple](using n: Labels[Q]): TensorTree[Tensor[Q]] with - def map(t: Tensor[Q], f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): Tensor[Q] = + given [Q <: Tuple, V](using n: Labels[Q], v: Value[V]): TensorTree[Tensor[Q, V]] with + def map(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T], Value[V2]) ?=> (Tensor[T, V2] => Tensor[T, V2])): Tensor[Q, V] = import TensorOps.retag - f[Q](using n)(t.retag[Q](using n)) + f[Q, V](using n, v)(t.retag[Q](using n)) - def zipMap(p1: Tensor[Q], p2: Tensor[Q], f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): Tensor[Q] = + def zipMap(p1: Tensor[Q, V], p2: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T], Value[V2]) ?=> ((Tensor[T, V2], Tensor[T, V2]) => Tensor[T, V2])): Tensor[Q, V] = import TensorOps.retag - f[Q](using n)(p1.retag[Q](using n), p2.retag[Q](using n)) + f[Q, V](using n, v)(p1.retag[Q](using n), p2.retag[Q](using n)) inline given derived[P <: Product](using m: Mirror.ProductOf[P]): TensorTree[P] = val elemInstances = summonAll[Tuple.Map[m.MirroredElemTypes, TensorTree]] @@ -31,13 +31,13 @@ object TensorTree extends TensorTreeLowPriority: instances: List[TensorTree[Any]], m: Mirror.ProductOf[P] ): TensorTree[P] = new TensorTree[P]: - def map(p: P, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): P = + def map(p: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): P = val inputs = p.productIterator.toList val mappedElems = inputs.zip(instances).map: case (elem, inst) => inst.map(elem, f) m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray)) - def zipMap(p1: P, p2: P, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): P = + def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P = val inputs1 = p1.productIterator.toList val inputs2 = p2.productIterator.toList val mappedElems = inputs1.zip(inputs2).zip(instances).map: @@ -46,5 +46,5 @@ object TensorTree extends TensorTreeLowPriority: trait TensorTreeLowPriority: given identity[A]: TensorTree[A] = new TensorTree[A]: - def map(p: A, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): A = p - def zipMap(p1: A, p2: A, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): A = p1 + def map(p: A, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): A = p + def zipMap(p1: A, p2: A, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): A = p1 diff --git a/core/src/main/scala/shapeful/jax/Jit.scala b/core/src/main/scala/shapeful/jax/Jit.scala index 064c15a..cb2672b 100644 --- a/core/src/main/scala/shapeful/jax/Jit.scala +++ b/core/src/main/scala/shapeful/jax/Jit.scala @@ -4,13 +4,13 @@ import shapeful.tensor.{Tensor, Shape, Labels} import shapeful.jax.{Jax, JaxDType} import shapeful.autodiff.ToPyTree import me.shadaj.scalapy.py - +import shapeful.tensor.Value object Jit: - def jit[PyTree: ToPyTree, 
OutT <: Tuple : Labels]( - f: PyTree => Tensor[OutT] - ): PyTree => Tensor[OutT] = + def jit[PyTree: ToPyTree, OutT <: Tuple : Labels, V : Value]( + f: PyTree => Tensor[OutT, V] + ): PyTree => Tensor[OutT, V] = // Python function that accepts a pytree val fpy = (pyTreePy: Jax.PyDynamic) => @@ -25,7 +25,7 @@ object Jit: (pyTree: PyTree) => val pyTreePy = ToPyTree[PyTree].toPyTree(pyTree) val resultJax = jitted(pyTreePy) - Tensor.fromPy[OutT](resultJax) + Tensor.fromPy[OutT, V](resultJax) def jit2[PyTree: ToPyTree, OutT <: Tuple : Labels]( f: PyTree => PyTree diff --git a/core/src/main/scala/shapeful/package.scala b/core/src/main/scala/shapeful/package.scala index 7d3ca3e..1c0e8b2 100644 --- a/core/src/main/scala/shapeful/package.scala +++ b/core/src/main/scala/shapeful/package.scala @@ -98,5 +98,5 @@ package object shapeful: export shapeful.autodiff.{Autodiff, TensorTree, ToPyTree} // Export Just-in-Time compilation -// export shapeful.jax.Jit.{jit, jit2} + export shapeful.jax.Jit.{jit, jit2} diff --git a/examples/src/main/scala/api/AutodiffAPI.scala b/examples/src/main/scala/api/AutodiffAPI.scala index 4a20901..6cf801f 100644 --- a/examples/src/main/scala/api/AutodiffAPI.scala +++ b/examples/src/main/scala/api/AutodiffAPI.scala @@ -4,112 +4,112 @@ import shapeful.* @main def autoDiffAPI(): Unit = - val AB = Tensor.ones(Shape( + val AB = Tensor.of[Float32].ones(Shape( Axis["A"] -> 10, Axis["B"] -> 5, )) - val AC = Tensor.ones(Shape( + val AC = Tensor.of[Float32].ones(Shape( Axis["A"]-> 10, Axis["C"] -> 5 )) - val ABCD = Tensor.ones(Shape( + val ABCD = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, Axis["D"] -> 5, )) { - def f(x: Tensor1["A"]): Tensor0 = x.sum + def f(x: Tensor1["A", Float32]): Tensor0[Float32] = x.sum val df = Autodiff.grad(f) - val delta = df(Tensor1(Axis["A"], Array.fill(10)(1.0f))) + val delta = df(Tensor1.of[Float32](Axis["A"], Array.fill(10)(1.0f))) println(delta.shape) } { - type ParamsTuple = (Tensor2["A", "B"], Tensor1["C"]) - def f(params: ParamsTuple): Tensor0 = + type ParamsTuple = (Tensor2["A", "B", Float32], Tensor1["C", Float32]) + def f(params: ParamsTuple): Tensor0[Float32] = params._1.sum + params._2.sum val df = Autodiff.grad(f) val delta = df(( - Tensor2(Axis["A"], Axis["B"], Array( + Tensor2.of[Float32](Axis["A"], Axis["B"], Array( Array.fill(5)(1.0f), Array.fill(5)(1.0f), )), - Tensor1(Axis["C"], Array.fill(5)(1.0f)) + Tensor1.of[Float32](Axis["C"], Array.fill(5)(1.0f)) )) println((delta._1.shape, delta._2.shape)) } { case class Params( - a: Tensor2["A", "B"], - b: Tensor1["C"], + a: Tensor2["A", "B", Float32], + b: Tensor1["C", Float32], ) derives TensorTree - def f(params: Params): Tensor0 = + def f(params: Params): Tensor0[Float32] = params.a.sum + params.b.sum val df = Autodiff.grad(f) val delta = df(Params( - Tensor2(Axis["A"], Axis["B"], Array( + Tensor2.of[Float32](Axis["A"], Axis["B"], Array( Array.fill(5)(1.0f), Array.fill(5)(1.0f), )), - Tensor1(Axis["C"], Array.fill(5)(1.0f)) + Tensor1.of[Float32](Axis["C"], Array.fill(5)(1.0f)) )) println(delta) } { - def f(x: Tensor1["A"]): Tensor1["A"] = x + def f(x: Tensor1["A", Float32]): Tensor1["A", Float32] = x val df = Autodiff.jacobian(f) - val delta = df(Tensor1(Axis["A"], Array.fill(10)(1.0f))) + val delta = df(Tensor1.of[Float32](Axis["A"], Array.fill(10)(1.0f))) println(delta.shape) } { - def f(x: Tensor1["A"]) = x.outerProduct(x) + def f(x: Tensor1["A", Float32]) = x.outerProduct(x) val df = Autodiff.jacobian(f) - val delta = df(Tensor1(Axis["A"], 
Array.fill(10)(1.0f))) + val delta = df(Tensor1.of[Float32](Axis["A"], Array.fill(10)(1.0f))) println(delta.shape) } { import shapeful.tensor.TensorOps.* - type ParamsTuple = (Tensor2["A", "B"], Tensor1["C"]) - def f(x: ParamsTuple): Tensor1["A"] = x._1.slice(Axis["B"] -> 0) + type ParamsTuple = (Tensor2["A", "B", Float32], Tensor1["C", Float32]) + def f(x: ParamsTuple): Tensor1["A", Float32] = x._1.slice(Axis["B"] -> 0) val df = Autodiff.jacobian(f) val delta = df(( - Tensor2(Axis["A"], Axis["B"], Array( + Tensor2.of[Float32](Axis["A"], Axis["B"], Array( Array.fill(5)(1.0f), Array.fill(5)(1.0f), )), - Tensor1(Axis["C"], Array.fill(5)(1.0f)) + Tensor1.of[Float32](Axis["C"], Array.fill(5)(1.0f)) )) println((delta._1.shape, delta._2.shape)) } { println("Hessian") - def f(x: Tensor1["A"]): Tensor0 = x.sum + def f(x: Tensor1["A", Float32]): Tensor0[Float32] = x.sum val df = Autodiff.jacobian(f) val ddf = Autodiff.jacobian(df) - val delta = ddf(Tensor1(Axis["A"], Array.fill(10)(1.0f))) + val delta = ddf(Tensor1.of[Float32](Axis["A"], Array.fill(10)(1.0f))) println(delta.shape) } { - def f(x: Tensor1["A"]): Tensor0 = x.sum + def f(x: Tensor1["A", Float32]): Tensor0[Float32] = x.sum val df = Autodiff.jacobian(f) val ddf = Autodiff.jacobian(df) val dddf = Autodiff.jacobian(ddf) val ddddf = Autodiff.jacobian(dddf) - val delta = ddddf(Tensor1(Axis["A"], Array.fill(10)(1.0f))) + val delta = ddddf(Tensor1.of[Float32](Axis["A"], Array.fill(10)(1.0f))) println(delta.shape) } { import shapeful.tensor.TensorOps.* - type ParamsTuple = (Tensor2["A", "B"], Tensor1["C"]) - def f(x: ParamsTuple): Tensor0 = x._1.sum + type ParamsTuple = (Tensor2["A", "B", Float32], Tensor1["C", Float32]) + def f(x: ParamsTuple): Tensor0[Float32] = x._1.sum val df = Autodiff.jacobian(f) val ddf = Autodiff.jacobian(df) val delta = ddf(( - Tensor2(Axis["A"], Axis["B"], Array( + Tensor2.of[Float32](Axis["A"], Axis["B"], Array( Array.fill(5)(1.0f), Array.fill(5)(1.0f), )), - Tensor1(Axis["C"], Array.fill(5)(1.0f)) + Tensor1.of[Float32](Axis["C"], Array.fill(5)(1.0f)) )) // TODO Is this actually correct, check it! 
println(( From f477cd94e33e177e506bf898c406f52166d9454b Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 10:02:31 +0100 Subject: [PATCH 05/10] putting nn modules back in with Value type parameter --- nn/src/main/scala/nn/Activation.scala | 15 ++++----- nn/src/main/scala/nn/GradientDescent.scala | 5 ++- nn/src/main/scala/nn/LinearLayer.scala | 38 +++++++++++++--------- nn/src/main/scala/nn/package.scala | 3 ++ 4 files changed, 34 insertions(+), 27 deletions(-) diff --git a/nn/src/main/scala/nn/Activation.scala b/nn/src/main/scala/nn/Activation.scala index 014afd7..cbda90e 100644 --- a/nn/src/main/scala/nn/Activation.scala +++ b/nn/src/main/scala/nn/Activation.scala @@ -2,23 +2,22 @@ package nn import shapeful.* import shapeful.jax.Jax -import shapeful.Conversions.given object ActivationFunctions: // TODO rewrite relu, sigmoid to JAX - def sigmoid[T <: Tuple : Labels](t: Tensor[T]): Tensor[T] = - val ones = Tensor.ones(t.shape) - val minust = t :* -1.0f + def sigmoid[T <: Tuple : Labels, V : Value](t: Tensor[T, V]): Tensor[T, V] = + val ones = Tensor.of[V].ones(t.shape) + val minust = t.scale(Tensor0.of[V].apply(-1.0f)) ones / (ones + (minust).exp) - def relu[T <: Tuple : Labels](t: Tensor[T]): Tensor[T] = - val zeros = Tensor.zeros(t.shape) + def relu[T <: Tuple : Labels, V : Value](t: Tensor[T, V]): Tensor[T, V] = + val zeros = Tensor.of[V].zeros(t.shape) maximum(t, zeros) - def gelu[T <: Tuple : Labels](t: Tensor[T]): Tensor[T] = + def gelu[T <: Tuple : Labels, V : Value](t: Tensor[T, V]): Tensor[T, V] = Tensor.fromPy(Jax.jnn.gelu(t.jaxValue)) - def softmax[L: Label](t: Tensor1[L]): Tensor1[L] = + def softmax[L: Label, V : Value](t: Tensor1[L, V]): Tensor1[L, V] = Tensor.fromPy(Jax.jnn.softmax(t.jaxValue, axis = 0)) diff --git a/nn/src/main/scala/nn/GradientDescent.scala b/nn/src/main/scala/nn/GradientDescent.scala index 61ec80d..cd2f770 100644 --- a/nn/src/main/scala/nn/GradientDescent.scala +++ b/nn/src/main/scala/nn/GradientDescent.scala @@ -1,11 +1,10 @@ package nn import shapeful.* -import shapeful.Conversions.given case class GradientDescent[Params](df: Params => Params, lr: Float): def step(params: Params)(using paramTree: TensorTree[Params]) = val gradients = df(params) - paramTree.zipMap(gradients, params, [T <: Tuple] => (n: Labels[T]) ?=> (g: Tensor[T], p: Tensor[T]) => - p - (g :* lr) + paramTree.zipMap(gradients, params, [T <: Tuple, V] => (n: Labels[T], v: Value[V]) ?=> (g: Tensor[T, V], p: Tensor[T, V]) => + p - g.scale(Tensor0.of[V].apply(lr)) ) \ No newline at end of file diff --git a/nn/src/main/scala/nn/LinearLayer.scala b/nn/src/main/scala/nn/LinearLayer.scala index fb6c7f6..9163fad 100644 --- a/nn/src/main/scala/nn/LinearLayer.scala +++ b/nn/src/main/scala/nn/LinearLayer.scala @@ -1,12 +1,12 @@ package nn import shapeful.* -import shapeful.Conversions.given +import shapeful.random.Random import shapeful.random.Random.Key object LinearLayer: - case class Params[In, Out](weight: Tensor2[In, Out], bias: Tensor1[Out]) + case class Params[In, Out](weight: Tensor2[In, Out, DType.Float32.type], bias: Tensor1[Out, DType.Float32.type]) object Params: given [I : Label, O : Label]: TensorTree[Params[I, O]] = TensorTree.derived @@ -15,30 +15,36 @@ object LinearLayer: def apply[In : Label, Out : Label](paramKey: Key)( inputDim: Dim[In], outputDim: Dim[Out], - ): Params[In, Out] = Params( - weight = Tensor.randn(Shape(inputDim, outputDim), paramKey), - bias = Tensor.zeros(Shape(outputDim)), - ) - -case class LinearLayer[In : Label,Out : Label](params: 
LinearLayer.Params[In, Out]) extends Function[Tensor1[In], Tensor1[Out]]: - override def apply(x: Tensor1[In]): Tensor1[Out] = + ): Params[In, Out] = + val mean = Tensor0.of[DType.Float32.type].apply(0f) + val std = Tensor0.of[DType.Float32.type].apply(1f) + Params( + weight = Random.normal(paramKey, Shape(inputDim, outputDim), mean, std), + bias = Tensor.of[DType.Float32.type].zeros(Shape(outputDim)), + ) + +case class LinearLayer[In : Label,Out : Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, DType.Float32.type], Tensor1[Out, DType.Float32.type]]: + override def apply(x: Tensor1[In, DType.Float32.type]): Tensor1[Out, DType.Float32.type] = import params.{weight, bias} x.contract(Axis[In])(weight) + bias object LinearMap: - case class Params[In](weight: Tensor1[In], bias: Tensor0) + case class Params[In](weight: Tensor1[In, DType.Float32.type], bias: Tensor0[DType.Float32.type]) object Params: given [In : Label]: TensorTree[Params[In]] = TensorTree.derived given [In : Label]: ToPyTree[Params[In]] = ToPyTree.derived - def apply[In : Label](paramKey: Key)(inputDim: Dim[In]): Params[In] = Params( - weight = Tensor.randn(Shape(inputDim), paramKey), - bias = Tensor0(0.0f), - ) + def apply[In : Label](paramKey: Key)(inputDim: Dim[In]): Params[In] = + val mean = Tensor0.of[DType.Float32.type].apply(0f) + val std = Tensor0.of[DType.Float32.type].apply(1f) + Params( + weight = Random.normal(paramKey, Shape(inputDim), mean, std), + bias = Tensor0.of[DType.Float32.type].apply(0.0f), + ) -case class LinearMap[In : Label](params: LinearMap.Params[In]) extends Function[Tensor1[In], Tensor0]: - override def apply(x: Tensor1[In]): Tensor0 = +case class LinearMap[In : Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, DType.Float32.type], Tensor0[DType.Float32.type]]: + override def apply(x: Tensor1[In, DType.Float32.type]): Tensor0[DType.Float32.type] = import params.{weight, bias} x.contract(Axis[In])(weight) + bias diff --git a/nn/src/main/scala/nn/package.scala b/nn/src/main/scala/nn/package.scala index e69de29..7ebdc2c 100644 --- a/nn/src/main/scala/nn/package.scala +++ b/nn/src/main/scala/nn/package.scala @@ -0,0 +1,3 @@ +package object nn { + +} From 36408a294a46163e8d15263c5cb02f0c5eeddc8a Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 11:12:55 +0100 Subject: [PATCH 06/10] fix some type parameters and add given conversion for tenor0 back in --- core/src/main/scala/shapeful/package.scala | 2 ++ core/src/main/scala/shapeful/random/Random.scala | 7 ++++--- core/src/main/scala/shapeful/tensor/Tensor.scala | 6 ++++++ nn/src/main/scala/nn/LinearLayer.scala | 12 ++++++------ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/shapeful/package.scala b/core/src/main/scala/shapeful/package.scala index 1c0e8b2..7dc391e 100644 --- a/core/src/main/scala/shapeful/package.scala +++ b/core/src/main/scala/shapeful/package.scala @@ -100,3 +100,5 @@ package object shapeful: // Export Just-in-Time compilation export shapeful.jax.Jit.{jit, jit2} + object Conversions: + export shapeful.tensor.Tensor0.{given_Conversion_Int_Tensor0, given_Conversion_Float_Tensor0} \ No newline at end of file diff --git a/core/src/main/scala/shapeful/random/Random.scala b/core/src/main/scala/shapeful/random/Random.scala index 1311312..acc98b8 100644 --- a/core/src/main/scala/shapeful/random/Random.scala +++ b/core/src/main/scala/shapeful/random/Random.scala @@ -4,6 +4,7 @@ import shapeful.tensor.* import shapeful.tensor.TensorOps.* import 
shapeful.jax.{Jax, JaxDType} import me.shadaj.scalapy.py.SeqConverters +import shapeful.Float32 /** JAX-based random number generation with proper key management. * @@ -52,9 +53,9 @@ object Random: def normal[T <: Tuple : Labels, V : Value]( key: Key, shape: Shape[T], - mean: Tensor0[V] = Tensor0.of[DType.Float32.type].apply(0f), - std: Tensor0[V] = Tensor0.of[DType.Float32.type].apply(1f) - )(using ev: V =:= DType.Float32.type): Tensor[T, V] = + mean: Tensor0[V] = Tensor0.of[Float32].apply(0f), + std: Tensor0[V] = Tensor0.of[Float32].apply(1f) + ): Tensor[T, V] = val dtype = summon[Value[V]].dtype val jaxValues = Jax.jrandom.normal( key.jaxKey, diff --git a/core/src/main/scala/shapeful/tensor/Tensor.scala b/core/src/main/scala/shapeful/tensor/Tensor.scala index 9b2c84f..c9547c4 100644 --- a/core/src/main/scala/shapeful/tensor/Tensor.scala +++ b/core/src/main/scala/shapeful/tensor/Tensor.scala @@ -8,6 +8,8 @@ import shapeful.jax.Jax.PyDynamic import shapeful.tensor.{Label, Labels} //import shapeful.random.Random import me.shadaj.scalapy.py.SeqConverters +import shapeful.Float32 +import shapeful.Int32 enum Device(val jaxDevice: PyDynamic): case CPU extends Device(Jax.devices("cpu").head.as[PyDynamic]) @@ -101,6 +103,10 @@ object Tensor0: def apply(value: Boolean): Tensor0[V] = Tensor0[V](Jax.jnp.array(value, dtype=dtype.jaxType)) + + given Conversion[Float, Tensor0[Float32]] = (x: Float) => Tensor0.of[Float32](x) + given Conversion[Int, Tensor0[Int32]] = (x: Int) => Tensor0.of[Int32](x) + def of[V : Value]: Tensor0Builder[V] = new Tensor0Builder[V] private[tensor] def apply[V : Value](jaxValue: Jax.PyDynamic): Tensor0[V] = Tensor[EmptyTuple, V](jaxValue) diff --git a/nn/src/main/scala/nn/LinearLayer.scala b/nn/src/main/scala/nn/LinearLayer.scala index 9163fad..a592805 100644 --- a/nn/src/main/scala/nn/LinearLayer.scala +++ b/nn/src/main/scala/nn/LinearLayer.scala @@ -30,21 +30,21 @@ case class LinearLayer[In : Label,Out : Label](params: LinearLayer.Params[In, Ou object LinearMap: - case class Params[In](weight: Tensor1[In, DType.Float32.type], bias: Tensor0[DType.Float32.type]) + case class Params[In](weight: Tensor1[In, Float32], bias: Tensor0[Float32]) object Params: given [In : Label]: TensorTree[Params[In]] = TensorTree.derived given [In : Label]: ToPyTree[Params[In]] = ToPyTree.derived def apply[In : Label](paramKey: Key)(inputDim: Dim[In]): Params[In] = - val mean = Tensor0.of[DType.Float32.type].apply(0f) - val std = Tensor0.of[DType.Float32.type].apply(1f) + val mean = Tensor0.of[Float32].apply(0f) + val std = Tensor0.of[Float32].apply(1f) Params( weight = Random.normal(paramKey, Shape(inputDim), mean, std), - bias = Tensor0.of[DType.Float32.type].apply(0.0f), + bias = Tensor0.of[Float32].apply(0.0f), ) -case class LinearMap[In : Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, DType.Float32.type], Tensor0[DType.Float32.type]]: - override def apply(x: Tensor1[In, DType.Float32.type]): Tensor0[DType.Float32.type] = +case class LinearMap[In : Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, Float32], Tensor0[Float32]]: + override def apply(x: Tensor1[In, Float32]): Tensor0[Float32] = import params.{weight, bias} x.contract(Axis[In])(weight) + bias From 1b9a7feefa1c72b34c2b00ef86083346b110e247 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 11:55:00 +0100 Subject: [PATCH 07/10] make LogisitcRegression and MLClassifier work again. 
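The examples now thread the element type explicitly: feature tensors are built with
Tensor.of[Float32] / Tensor2.of[Float32], labels with Tensor1.of[Int32], scalar
constants with Tensor0.of[Float32], and mixed-dtype arithmetic goes through asType.
Roughly (an illustrative sketch with placeholder names, not code from the diffs below):

    val data     = Tensor2.of[Float32](Axis[Sample], Axis[Feature], featureValues)
    val labels   = Tensor1.of[Int32](Axis[Sample], labelValues)
    // labels are integral, so convert them before comparing with Float32 predictions
    val accuracy = 1f - (predictions - labels.asType[Float32]).abs.mean
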
--- .../main/scala/basic/LogisticRegression.scala | 25 +++---- .../main/scala/basic/MLClassifierMNist.scala | 75 +++++++++++-------- nn/src/main/scala/nn/LinearLayer.scala | 12 +-- 3 files changed, 62 insertions(+), 50 deletions(-) diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index ab86005..9eaa2d8 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -18,12 +18,11 @@ object LogisticRegression: case class BinaryLogisticRegression( params: BinaryLogisticRegression.Params, - ) extends Function[Tensor1[Feature], Tensor0]: + ) extends Function[Tensor1[Feature, Float32], Tensor0[Float32]]: private val linear = LinearMap(params.linearMap) - def logits(input: Tensor1[Feature]): Tensor0 = linear(input) - def probits(input: Tensor1[Feature]): Tensor0 = sigmoid(logits(input)) - def apply(input: Tensor1[Feature]): Tensor0 = logits(input) >= Tensor0(0f) - + def logits(input: Tensor1[Feature, Float32]): Tensor0[Float32] = linear(input) + def probits(input: Tensor1[Feature, Float32]): Tensor0[Float32] = sigmoid(logits(input)) + def apply(input: Tensor1[Feature, Float32]): Tensor0[Float32] = logits(input) >= Tensor0.of[Float32](0f) def main(args: Array[String]): Unit = import io.github.quafadas.table.* @@ -44,23 +43,23 @@ object LogisticRegression: }.toArray val labelData = dfShuffled.column["species"].toArray.map(_.toFloat) - val dataUnnormalized = Tensor2(Axis[Sample], Axis[Feature], featureData) - val dataLabels = Tensor1(Axis[Sample], labelData) + val dataUnnormalized = Tensor2.of[Float32](Axis[Sample], Axis[Feature], featureData) + val dataLabels = Tensor1.of[Int32](Axis[Sample], labelData) // TODO implement split val (trainingDataUnnormalized, valDataUnnormalized) = (dataUnnormalized, dataUnnormalized) val (trainLabels, valLabels) = (dataLabels, dataLabels) - def calcMeanAndStd(t: Tensor2[Sample, Feature]): (Tensor1[Feature], Tensor1[Feature]) = + def calcMeanAndStd(t: Tensor2[Sample, Feature, Float32]): (Tensor1[Feature, Float32], Tensor1[Feature, Float32]) = val mean = t.vmap(Axis[Feature])(_.mean) val std = zipvmap(Axis[Feature])(t, mean): case (x, m) => val epsilon = 1e-6f - (x :- m).pow(2).mean.sqrt + epsilon + (x :- m).pow(2f).mean.sqrt + epsilon // x.vmap(Axis[Sample])(xi => (xi - m).pow(2)).mean.sqrt + epsilon (mean, std) - def standardizeData(mean: Tensor1[Feature], std: Tensor1[Feature])(data: Tensor2[Sample, Feature]): Tensor2[Sample, Feature] = + def standardizeData(mean: Tensor1[Feature, Float32], std: Tensor1[Feature, Float32])(data: Tensor2[Sample, Feature, Float32]): Tensor2[Sample, Feature, Float32] = data.vapply(Axis[Feature])(feature => (feature - mean) / std) // (data :- mean) :/ std @@ -72,7 +71,7 @@ object LogisticRegression: val (initKey, restKey) = trainKey.split2() val (lossKey, sampleKey) = restKey.split2() - def loss(data: Tensor2[Sample, Feature])(params: BinaryLogisticRegression.Params): Tensor0 = + def loss(data: Tensor2[Sample, Feature, Float32])(params: BinaryLogisticRegression.Params): Tensor0[Float32] = val model = BinaryLogisticRegression(params) val losses = zipvmap(Axis[Sample])(data, trainLabels): case (sample, label) => @@ -99,8 +98,8 @@ object LogisticRegression: val valPreds = valData.vmap(Axis[Sample])(model) println(List( "epoch: " + index, - "trainAcc: " + (1 - (trainPreds - trainLabels).abs.mean), - "valAcc: " + (1 - (valPreds - valLabels).abs.mean) + "trainAcc: " + (1f - (trainPreds - 
trainLabels.asType[Float32]).abs.mean), + "valAcc: " + (1f - (valPreds - valLabels.asType[Float32]).abs.mean) ).mkString(", ")) .map((params, _) => params) .drop(2500) diff --git a/examples/src/main/scala/basic/MLClassifierMNist.scala b/examples/src/main/scala/basic/MLClassifierMNist.scala index 43bb32a..a9f3618 100644 --- a/examples/src/main/scala/basic/MLClassifierMNist.scala +++ b/examples/src/main/scala/basic/MLClassifierMNist.scala @@ -1,7 +1,6 @@ package examples.basic import shapeful.* -import shapeful.Conversions.given import nn.* import nn.ActivationFunctions.{relu, sigmoid} import shapeful.random.Random @@ -11,8 +10,8 @@ import scala.util.Try import java.io.{FileInputStream, DataInputStream, BufferedInputStream} def binaryCrossEntropy[L : Label]( - logits: Tensor1[L], label: Tensor0 -): Tensor0 = + logits: Tensor1[L, Float32], label: Tensor0[Int32] +): Tensor0[Float32] = val maxLogit = logits.max val stableExp = (logits :- maxLogit).exp val logSumExp = stableExp.sum.log + maxLogit @@ -48,23 +47,23 @@ object MLPClassifierMNist: layer2 = LinearLayer.Params(key2)(layer2Dim, outputDim), ) - case class MLP(params: MLP.Params) extends Function[Tensor2[Height, Width], Tensor0]: + case class MLP(params: MLP.Params) extends Function[Tensor2[Height, Width, Float32], Tensor0[Float32]]: private val layer1 = LinearLayer(params.layer1) private val layer2 = LinearLayer(params.layer2) def logits( - image: Tensor2[Height, Width], - ): Tensor1[Output] = + image: Tensor2[Height, Width, Float32], + ): Tensor1[Output, Float32] = val hidden = relu(layer1(image.ravel)) layer2(hidden) - override def apply(image: Tensor2[Height, Width]): Tensor0 = logits(image).argmax(Axis[Output]) + override def apply(image: Tensor2[Height, Width, Float32]): Tensor0[Float32] = logits(image).argmax(Axis[Output]) object MNISTLoader: private def readInt(dis: DataInputStream): Int = dis.readInt() - private def loadImagePixels[S <: Sample : Label](filename: String, maxImages: Option[Int] = None): Try[Tensor3[S, Height, Width]] = + private def loadImagePixels[S <: Sample : Label](filename: String, maxImages: Option[Int] = None): Try[Tensor3[S, Height, Width, Float32]] = Try { val dis = new DataInputStream(new BufferedInputStream(new FileInputStream(filename))) try @@ -86,12 +85,12 @@ object MLPClassifierMNist: // Convert bytes to floats with vectorized operation val allPixels = pixelBytes.map(b => (b & 0xff) / 255.0f) val shape = Shape(Axis[S] -> numImages, Axis[Height] -> rows, Axis[Width] -> cols) - val tensor = Tensor3(shape, allPixels, DType.Float32) + val tensor = Tensor3.of[Float32](shape, allPixels) tensor.toDevice(Device.CPU) finally dis.close() } - private def loadLabelsArray[S <: Sample : Label](filename: String, maxLabels: Option[Int] = None): Try[Tensor1[S]] = Try { + private def loadLabelsArray[S <: Sample : Label](filename: String, maxLabels: Option[Int] = None): Try[Tensor1[S, Int32]] = Try { val dis = new DataInputStream(new BufferedInputStream(new FileInputStream(filename))) try val magic = readInt(dis) @@ -106,12 +105,12 @@ object MLPClassifierMNist: for i <- 0.until(numLabels) do labels(i) = dis.readUnsignedByte() // Create Tensor1 from labels - specify the label type correctly - val tensor = Tensor1.fromInts(Axis[S], labels, DType.Int32) + val tensor = Tensor1.of[Int32].fromInts(Axis[S], labels) tensor.toDevice(Device.CPU) finally dis.close() } - private def createDataset[S <: Sample : Label](imagesFile: String, labelsFile: String, maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(S, Height, Width)], 
Tensor1[S]]] = + private def createDataset[S <: Sample : Label](imagesFile: String, labelsFile: String, maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(S, Height, Width), Float32], Tensor1[S, Int32]]] = for imagePixels <- loadImagePixels[S](imagesFile, maxSamples) labels <- loadLabelsArray[S](labelsFile, maxSamples) @@ -123,12 +122,12 @@ object MLPClassifierMNist: println(s"Created in-memory MNIST dataset with $numImages images") (imagePixels, labels) - def createTrainingDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TrainSample, Height, Width)], Tensor1[TrainSample]]] = + def createTrainingDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TrainSample, Height, Width), Float32], Tensor1[TrainSample, Int32]]] = val imagesFile = s"$dataDir/train-images-idx3-ubyte" val labelsFile = s"$dataDir/train-labels-idx1-ubyte" createDataset[TrainSample](imagesFile, labelsFile, maxSamples) - def createTestDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TestSample, Height, Width)], Tensor1[TestSample]]] = + def createTestDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TestSample, Height, Width), Float32], Tensor1[TestSample, Int32]]] = val imagesFile = s"$dataDir/t10k-images-idx3-ubyte" val labelsFile = s"$dataDir/t10k-labels-idx1-ubyte" createDataset[TestSample](imagesFile, labelsFile, maxSamples) @@ -145,15 +144,18 @@ object MLPClassifierMNist: val (trainX, trainY) = MNISTLoader.createTrainingDataset(maxSamples = Some(numSamples)).get val (testX, testY) = MNISTLoader.createTestDataset(maxSamples = Some(1024)).get - def batchLoss(batchImages: Tensor[(TrainSample, Height, Width)], batchLabels: Tensor1[TrainSample])( + def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float32], batchLabels: Tensor1[TrainSample, Int32])( params: MLP.Params - ): Tensor0 = + ): Tensor0[Float32] = val model = MLP(params) - val losses = zipvmap(Axis[TrainSample])(batchImages, batchLabels): - case (image, label) => - val logits = model.logits(image) - binaryCrossEntropy(logits, label) - losses.mean + val batchSize = batchImages.shape(Axis[TrainSample]) + val losses = (0 until batchSize).map: idx => + val image = batchImages.slice(Axis[TrainSample] -> idx) + val label = batchLabels.slice(Axis[TrainSample] -> idx) + val logits = model.logits(image) + binaryCrossEntropy(logits, label) + .reduce(_ + _) + losses / Tensor0.of[Float32].apply(batchSize.toFloat) val initParams = MLP.Params( Axis[Height |*| Width] -> 28 * 28, @@ -161,14 +163,19 @@ object MLPClassifierMNist: Axis[Output] -> 10 )(initKey) - def accuracy[S <: Sample : Label](predictions: Tensor1[S], targets: Tensor1[S]): Tensor0 = - val matches = zipvmap(Axis[S])(predictions, targets): - case (pred, target) => Tensor0(pred.toInt == target.toInt) - matches.mean + def accuracy[S <: Sample : Label](predictions: Tensor1[S, Float32], targets: Tensor1[S, Int32]): Tensor0[Float32] = + val numSamples = predictions.shape(Axis[S]) + val matches = (0 until numSamples).map: idx => + val pred = predictions.slice(Axis[S] -> idx) + val target = targets.slice(Axis[S] -> idx) + val isMatch = if pred.toFloat.toInt == target.toInt then 1.0f else 0.0f + Tensor0.of[Float32].apply(isMatch) + .reduce(_ + _) + matches / Tensor0.of[Float32].apply(numSamples.toFloat) def miniBatchGradientDescent( - imageBatches: Seq[Tensor[(TrainSample, Height, Width)]], - labelBatches: Seq[Tensor1[TrainSample]], + imageBatches: 
Seq[Tensor[(TrainSample, Height, Width), Float32]], + labelBatches: Seq[Tensor1[TrainSample, Int32]], )( params: MLP.Params ): MLP.Params = @@ -200,10 +207,16 @@ object MLPClassifierMNist: case (params, epoch) => timed("Evaluation"): val model = MLP(params) - val testPreds = testX.vmap(Axis[TestSample])(model) - val testAccuracy = accuracy(testPreds, testY) - val trainPreds = trainX.vmap(Axis[TrainSample])(model) - val trainAccuracy = accuracy(trainPreds, trainY) + val testPreds = testX.vmap(Axis[TestSample]): img => + val pred = model(img) + pred.prependAxis(Axis["Pred"]) + val testPredictions = testPreds.squeeze(Axis["Pred"]) + val testAccuracy = accuracy(testPredictions, testY) + val trainPreds = trainX.vmap(Axis[TrainSample]): img => + val pred = model(img) + pred.prependAxis(Axis["Pred"]) + val trainPredictions = trainPreds.squeeze(Axis["Pred"]) + val trainAccuracy = accuracy(trainPredictions, trainY) println(List( s"Epoch $epoch", f"Test accuracy: ${testAccuracy.toFloat * 100}%.2f%%", diff --git a/nn/src/main/scala/nn/LinearLayer.scala b/nn/src/main/scala/nn/LinearLayer.scala index a592805..9307051 100644 --- a/nn/src/main/scala/nn/LinearLayer.scala +++ b/nn/src/main/scala/nn/LinearLayer.scala @@ -6,7 +6,7 @@ import shapeful.random.Random.Key object LinearLayer: - case class Params[In, Out](weight: Tensor2[In, Out, DType.Float32.type], bias: Tensor1[Out, DType.Float32.type]) + case class Params[In, Out](weight: Tensor2[In, Out, Float32], bias: Tensor1[Out, Float32]) object Params: given [I : Label, O : Label]: TensorTree[Params[I, O]] = TensorTree.derived @@ -16,15 +16,15 @@ object LinearLayer: inputDim: Dim[In], outputDim: Dim[Out], ): Params[In, Out] = - val mean = Tensor0.of[DType.Float32.type].apply(0f) - val std = Tensor0.of[DType.Float32.type].apply(1f) + val mean = Tensor0.of[Float32].apply(0f) + val std = Tensor0.of[Float32].apply(1f) Params( weight = Random.normal(paramKey, Shape(inputDim, outputDim), mean, std), - bias = Tensor.of[DType.Float32.type].zeros(Shape(outputDim)), + bias = Tensor.of[Float32].zeros(Shape(outputDim)), ) -case class LinearLayer[In : Label,Out : Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, DType.Float32.type], Tensor1[Out, DType.Float32.type]]: - override def apply(x: Tensor1[In, DType.Float32.type]): Tensor1[Out, DType.Float32.type] = +case class LinearLayer[In : Label,Out : Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, Float32], Tensor1[Out, Float32]]: + override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = import params.{weight, bias} x.contract(Axis[In])(weight) + bias From 8f2809a3a71ab3e0974ad1bdb966d7bd936ed559 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 16:07:26 +0100 Subject: [PATCH 08/10] make tensorapi compile. --- examples/src/main/scala/api/TensorAPI.scala | 36 ++++++++++----------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/src/main/scala/api/TensorAPI.scala b/examples/src/main/scala/api/TensorAPI.scala index a0b2020..70872d8 100644 --- a/examples/src/main/scala/api/TensorAPI.scala +++ b/examples/src/main/scala/api/TensorAPI.scala @@ -7,7 +7,7 @@ import me.shadaj.scalapy.py.PythonException def opBlock[T](operation: String)(block: => T): Unit = val res = block block match - case t: Tensor[?] 
=> + case t: Tensor[?,Float32] => println(f"$operation%-30s: ${t.shape}%-30s == ${py.eval("res.shape if hasattr(res, 'shape') else res")}") case v => println(f"$operation%-30s: $v%-30s == ${py.eval("res")}") @@ -19,28 +19,28 @@ def tensorAPI(): Unit = py.exec("import jax.numpy as jnp") py.exec("import einops") // py.eval("import jax.numpy as jnp") - val AB = Tensor.ones(Shape( + val AB = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, )) py.exec("ab = jnp.ones((2, 3))") - val AC = Tensor.ones(Shape( + val AC = Tensor.of[Float32].ones(Shape( Axis["A"]-> 2, Axis["C"] -> 4 )) py.exec("ac = jnp.ones((2, 4))") - val BCD = Tensor.ones(Shape( + val BCD = Tensor.of[Float32].ones(Shape( Axis["B"]-> 3, Axis["C"] -> 4, Axis["D"] -> 5, )) py.exec("bcd = jnp.ones((3, 4, 5))") - val BC = Tensor.ones(Shape( + val BC = Tensor.of[Float32].ones(Shape( Axis["B"]-> 3, Axis["C"] -> 4, )) py.exec("bc = jnp.ones((3, 4))") - val ABCD = Tensor.ones(Shape( + val ABCD = Tensor.of[Float32].ones(Shape( Axis["A"] -> 2, Axis["B"] -> 3, Axis["C"] -> 4, @@ -63,12 +63,12 @@ def tensorAPI(): Unit = } opBlock("Scalar broadcast: ABCD + Scalar") { py.exec("res = abcd + 5") - ABCD :+ Tensor0(5) + ABCD :+ Tensor0.of[Float32](5) } opBlock("Axes broadcasting backward: ABCD + CD") { py.exec("cd = jnp.ones((4,5))") py.exec("res = abcd + cd") - val CD = Tensor.ones(Shape( + val CD = Tensor.of[Float32].ones(Shape( Axis["C"] -> 4, Axis["D"] -> 5, )) ABCD :+ CD @@ -139,7 +139,7 @@ def tensorAPI(): Unit = } opBlock("pow") { py.exec("res = ab ** 2") - AB.pow(Tensor0(2)) + AB.pow(Tensor0.of[Float32](2)) } opBlock("sqrt") { py.exec("res = jnp.sqrt(ab)") @@ -339,7 +339,7 @@ def tensorAPI(): Unit = py.exec("res = ab.at[0, :].set(jnp.array([0,1,2]))") AB.set( Axis["A"] -> 0 - )(Tensor1(Axis["B"], Array(0, 1, 2))) + )(Tensor1.of[Float32](Axis["B"], Array(0, 1, 2))) } // set sub-matrix, AB.at[0:1, 0:1].set([[1,2],[3,4]]) opBlock("set ab axis=0:1,0:1") { @@ -347,7 +347,7 @@ def tensorAPI(): Unit = AB.set(( // TODO make (()) optional Axis["A"] -> (0 until 2), Axis["B"] -> (0 until 2), - ))(Tensor2( + ))(Tensor2.of[Float32]( Axis["A"], Axis["B"], Array( @@ -453,7 +453,7 @@ def tensorAPI(): Unit = */ opBlock("squeeze A from AB") { py.exec("tmp = jnp.ones((1,3))") // Setup - val tmp = Tensor.ones(Shape( + val tmp = Tensor.of[Float32].ones(Shape( Axis["A"] -> 1, Axis["B"] -> 3, )) @@ -516,9 +516,9 @@ def tensorAPI(): Unit = py.exec("condition = jnp.zeros(shape)") py.exec("res = jnp.where(condition, x, y)") val shape = Shape(Axis["A"] -> 2, Axis["B"] -> 3) - val x = Tensor.ones(shape) - val y = Tensor.zeros(shape) - val condition = Tensor.zeros(shape) + val x = Tensor.of[Float32].ones(shape) + val y = Tensor.of[Float32].zeros(shape) + val condition = Tensor.of[Bool].zeros(shape) where(condition, x, y) } } @@ -543,7 +543,7 @@ def tensorAPI(): Unit = AB.trace } opBlock("det abc1c2 over axes C1 and C2") { - val ABC1C2 = Tensor.ones( + val ABC1C2 = Tensor.of[Float32].ones( Shape(Axis["A"] -> 2, Axis["B"] -> 3, Axis["C1"] -> 4, Axis["C2"] -> 4) ) py.exec("abc1c2 = jnp.ones((2, 3, 4, 4))") @@ -551,7 +551,7 @@ def tensorAPI(): Unit = ABC1C2.det(Axis["C1"], Axis["C2"]) } opBlock("det A1A2") { - val A1A2 = Tensor.ones( + val A1A2 = Tensor.of[Float32].ones( Shape(Axis["A1"] -> 2, Axis["A2"] -> 2) ) py.exec("a1a2 = jnp.ones((2, 2))") @@ -563,7 +563,7 @@ def tensorAPI(): Unit = AB.norm } opBlock("inv AB") { - val a1a2 = Tensor2( + val a1a2 = Tensor2.of[Float32]( Axis["A1"], Axis["A2"], Array( From e49d36f14c1f46f394c6bc58e60992878bc51b8e Mon Sep 17 
00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 16:07:49 +0100 Subject: [PATCH 09/10] uncomment GPT2 it will be rewritten in the future anyways. --- examples/src/main/scala/basic/GPT2.scala | 728 +++++++++++------------ 1 file changed, 364 insertions(+), 364 deletions(-) diff --git a/examples/src/main/scala/basic/GPT2.scala b/examples/src/main/scala/basic/GPT2.scala index 807bb60..21e57ac 100644 --- a/examples/src/main/scala/basic/GPT2.scala +++ b/examples/src/main/scala/basic/GPT2.scala @@ -1,382 +1,382 @@ -package examples.basic - -import shapeful.* -import shapeful.Conversions.given - -import nn.ActivationFunctions.* - -// Dimensions -trait Vocab derives Label // 50257 -trait Embedding derives Label // 768 -trait Context derives Label // 1024 -trait Inner derives Label // 3072 - -trait Batch derives Label - -case class LNParams( - weight: Tensor1[Embedding], - bias: Tensor1[Embedding] -) - -case class LinearParams[In, Out]( - weight: Tensor2[In, Out], - bias: Tensor1[Out] -) - -// Values stored as Query, Key, Value triplets -// type QKV[L <: String] = L * L * L -// case class AttentionParams( - // Fused Q,K,V projection (768 -> 2304) -// cAttn: LinearParams[Embedding, QKV[Embedding]], - // Output projection (768 -> 768) -// cProj: LinearParams[Embedding, Embedding] -//) - - -trait Heads derives Label -trait Key derives Label -trait Query derives Label -trait Value derives Label - -case class MultiHeadAttentionParams( - WK : Tensor3[Heads, Embedding, Key], - WKBias: Tensor2[Heads, Key], - WQ : Tensor3[Heads, Embedding, Query], - WQBias: Tensor2[Heads, Query], - WV : Tensor3[Heads, Embedding, Value], - WVBias: Tensor2[Heads, Value], - proj: LinearParams[Heads |*| Value, Embedding], -) derives ToPyTree - -type QKV = Heads |*| Query |*| Heads |*| Key |*| Heads |*| Value - -object MultiHeadAttentionParams: - def apply( - cAttn: LinearParams[Embedding, QKV], - cProj: LinearParams[Heads |*| Value, Embedding], - numHeads: Int, - ): MultiHeadAttentionParams = - def splitWeightToHeads[L](t: Tensor2[Embedding, Heads |*| L], numHeads: Int)(using label: Label[L]): Tensor3[Heads, Embedding, L] = - val tLength = t.shape(Axis[Heads |*| L]) - require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") - t.rearrange( - (Axis[Heads], Axis[Embedding], Axis[L]), - (Axis[Heads] -> numHeads, Axis[L] ->(tLength / numHeads)), - ) - def splitBiasToHeads[L](t: Tensor1[Heads |*| L], numHeads: Int)(using label: Label[L]): Tensor2[Heads, L] = - val tLength = t.shape(Axis[Heads |*| L]) - require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") - t.rearrange( - (Axis[Heads], Axis[L]), - (Axis[Heads] -> numHeads, Axis[L] ->(tLength / numHeads)), - ) - val qkvLength = cAttn.weight.shape(Axis[QKV]) - require(qkvLength % 3 == 0, s"QKV length $qkvLength not divisible by 3") - val (qLength, kLength, vLength) = (qkvLength / 3, qkvLength / 3, qkvLength / 3) +// package examples.basic + +// import shapeful.* +// import shapeful.Conversions.given + +// import nn.ActivationFunctions.* + +// // Dimensions +// trait Vocab derives Label // 50257 +// trait Embedding derives Label // 768 +// trait Context derives Label // 1024 +// trait Inner derives Label // 3072 + +// trait Batch derives Label + +// case class LNParams( +// weight: Tensor1[Embedding], +// bias: Tensor1[Embedding] +// ) + +// case class LinearParams[In, Out]( +// weight: Tensor2[In, Out], +// bias: Tensor1[Out] +// ) + +// // Values stored as Query, Key, Value triplets +// // type QKV[L <: 
String] = L * L * L +// // case class AttentionParams( +// // Fused Q,K,V projection (768 -> 2304) +// // cAttn: LinearParams[Embedding, QKV[Embedding]], +// // Output projection (768 -> 768) +// // cProj: LinearParams[Embedding, Embedding] +// //) + + +// trait Heads derives Label +// trait Key derives Label +// trait Query derives Label +// trait Value derives Label + +// case class MultiHeadAttentionParams( +// WK : Tensor3[Heads, Embedding, Key], +// WKBias: Tensor2[Heads, Key], +// WQ : Tensor3[Heads, Embedding, Query], +// WQBias: Tensor2[Heads, Query], +// WV : Tensor3[Heads, Embedding, Value], +// WVBias: Tensor2[Heads, Value], +// proj: LinearParams[Heads |*| Value, Embedding], +// ) derives ToPyTree + +// type QKV = Heads |*| Query |*| Heads |*| Key |*| Heads |*| Value + +// object MultiHeadAttentionParams: +// def apply( +// cAttn: LinearParams[Embedding, QKV], +// cProj: LinearParams[Heads |*| Value, Embedding], +// numHeads: Int, +// ): MultiHeadAttentionParams = +// def splitWeightToHeads[L](t: Tensor2[Embedding, Heads |*| L], numHeads: Int)(using label: Label[L]): Tensor3[Heads, Embedding, L] = +// val tLength = t.shape(Axis[Heads |*| L]) +// require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") +// t.rearrange( +// (Axis[Heads], Axis[Embedding], Axis[L]), +// (Axis[Heads] -> numHeads, Axis[L] ->(tLength / numHeads)), +// ) +// def splitBiasToHeads[L](t: Tensor1[Heads |*| L], numHeads: Int)(using label: Label[L]): Tensor2[Heads, L] = +// val tLength = t.shape(Axis[Heads |*| L]) +// require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") +// t.rearrange( +// (Axis[Heads], Axis[L]), +// (Axis[Heads] -> numHeads, Axis[L] ->(tLength / numHeads)), +// ) +// val qkvLength = cAttn.weight.shape(Axis[QKV]) +// require(qkvLength % 3 == 0, s"QKV length $qkvLength not divisible by 3") +// val (qLength, kLength, vLength) = (qkvLength / 3, qkvLength / 3, qkvLength / 3) - val wq = cAttn.weight.slice(Axis[QKV] -> (0 until qLength)).relabel(Axis[QKV] -> Axis[Heads |*| Query]) - val wkb = cAttn.bias.slice(Axis[QKV] -> (qLength until qLength + kLength)).relabel(Axis[QKV] -> Axis[Heads |*| Key]) - val wk = cAttn.weight.slice(Axis[QKV] -> (qLength until qLength + kLength)).relabel(Axis[QKV] -> Axis[Heads |*| Key]) - val wqb = cAttn.bias.slice(Axis[QKV] -> (0 until qLength)).relabel(Axis[QKV] -> Axis[Heads |*| Query]) - val wv = cAttn.weight.slice(Axis[QKV] -> (qLength + kLength until qkvLength)).relabel(Axis[QKV] -> Axis[Heads |*| Value]) - val wvb = cAttn.bias.slice(Axis[QKV] -> (qLength + kLength until qkvLength)).relabel(Axis[QKV] -> Axis[Heads |*| Value]) +// val wq = cAttn.weight.slice(Axis[QKV] -> (0 until qLength)).relabel(Axis[QKV] -> Axis[Heads |*| Query]) +// val wkb = cAttn.bias.slice(Axis[QKV] -> (qLength until qLength + kLength)).relabel(Axis[QKV] -> Axis[Heads |*| Key]) +// val wk = cAttn.weight.slice(Axis[QKV] -> (qLength until qLength + kLength)).relabel(Axis[QKV] -> Axis[Heads |*| Key]) +// val wqb = cAttn.bias.slice(Axis[QKV] -> (0 until qLength)).relabel(Axis[QKV] -> Axis[Heads |*| Query]) +// val wv = cAttn.weight.slice(Axis[QKV] -> (qLength + kLength until qkvLength)).relabel(Axis[QKV] -> Axis[Heads |*| Value]) +// val wvb = cAttn.bias.slice(Axis[QKV] -> (qLength + kLength until qkvLength)).relabel(Axis[QKV] -> Axis[Heads |*| Value]) - MultiHeadAttentionParams( - WQ = splitWeightToHeads(wq, numHeads), - WQBias = splitBiasToHeads(wqb, numHeads), - WK = splitWeightToHeads(wk, numHeads), - WKBias = 
splitBiasToHeads(wkb, numHeads), - WV = splitWeightToHeads(wv, numHeads), - WVBias = splitBiasToHeads(wvb, numHeads), - proj = cProj, - ) - - -case class MLPParams( - c_fc: LinearParams[Embedding, Inner], // 768 -> 3072 - c_proj: LinearParams[Inner, Embedding] // 3072 -> 768 -) - -type WTEParams = Tensor2[Vocab, Embedding] -type WPEParams = Tensor2[Context, Embedding] - -case class HiddenParams( - ln1 : LNParams, - attn : MultiHeadAttentionParams, - ln2 : LNParams, - mlp: MLPParams -) - -case class GPT2Params( - wpe: WPEParams, - wte: WTEParams, - layers: List[HiddenParams], - ln_f : LNParams -) - -case class GPT2(params: GPT2Params): - - private case class LinearLayer[In : Label, Out : Label](params: LinearParams[In, Out]) extends Function[Tensor1[In], Tensor1[Out]]: - override def apply(x: Tensor1[In]): Tensor1[Out] = - x.contract(Axis[In])(params.weight) + params.bias - - private case class MLP(params: MLPParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: - - private val hiddenLayer = LinearLayer(params.c_fc) - private val outputLayer = LinearLayer(params.c_proj) - // TODO add dropout - - def apply(in: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = - in.vmap(Axis[Context])(x => - val hidden = gelu(hiddenLayer(x)) - outputLayer(hidden) - ) - - private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: - - private val projection = LinearLayer(params.proj) +// MultiHeadAttentionParams( +// WQ = splitWeightToHeads(wq, numHeads), +// WQBias = splitBiasToHeads(wqb, numHeads), +// WK = splitWeightToHeads(wk, numHeads), +// WKBias = splitBiasToHeads(wkb, numHeads), +// WV = splitWeightToHeads(wv, numHeads), +// WVBias = splitBiasToHeads(wvb, numHeads), +// proj = cProj, +// ) + + +// case class MLPParams( +// c_fc: LinearParams[Embedding, Inner], // 768 -> 3072 +// c_proj: LinearParams[Inner, Embedding] // 3072 -> 768 +// ) + +// type WTEParams = Tensor2[Vocab, Embedding] +// type WPEParams = Tensor2[Context, Embedding] + +// case class HiddenParams( +// ln1 : LNParams, +// attn : MultiHeadAttentionParams, +// ln2 : LNParams, +// mlp: MLPParams +// ) + +// case class GPT2Params( +// wpe: WPEParams, +// wte: WTEParams, +// layers: List[HiddenParams], +// ln_f : LNParams +// ) + +// case class GPT2(params: GPT2Params): + +// private case class LinearLayer[In : Label, Out : Label](params: LinearParams[In, Out]) extends Function[Tensor1[In], Tensor1[Out]]: +// override def apply(x: Tensor1[In]): Tensor1[Out] = +// x.contract(Axis[In])(params.weight) + params.bias + +// private case class MLP(params: MLPParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: + +// private val hiddenLayer = LinearLayer(params.c_fc) +// private val outputLayer = LinearLayer(params.c_proj) +// // TODO add dropout + +// def apply(in: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = +// in.vmap(Axis[Context])(x => +// val hidden = gelu(hiddenLayer(x)) +// outputLayer(hidden) +// ) + +// private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: + +// private val projection = LinearLayer(params.proj) - def apply(X : Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = - val heads = zipvmap(Axis[Heads])(params.WQ, params.WK, params.WV): (wqi, wki, wvi) => - attention(wqi, wki, wvi)(X) - heads.vmap(Axis[Context])(heads => projection(heads.ravel)) - - private 
def attention( - wq : Tensor2[Embedding, Query], - wk : Tensor2[Embedding, Key], - wv : Tensor2[Embedding, Value]) - (x : Tensor2[Context, Embedding]): Tensor2[Context, Value] = - trait AttnWeights derives Label - val q = x.contract(Axis[Embedding])(wq) - val k = x.contract(Axis[Embedding])(wk) - val v = x.contract(Axis[Embedding])(wv) - val dk = Tensor0(Math.sqrt(k.shape(Axis[Key])).toFloat) - val attnWeights = (q.contract(Axis[Query ~ Key])(k) :/ dk) - .vmap(Axis[Context])(x => softmax(x).relabelTo(Axis[AttnWeights])) - val result = attnWeights.contract(Axis[AttnWeights ~ Context])(v) - result - - private case class LayerNorm(params: LNParams) extends Function[Tensor1[Embedding], Tensor1[Embedding]]: - - private def standardize(x: Tensor1[Embedding]): Tensor1[Embedding] = - val mean = x.mean - val x0 = x :- mean - val variance = x0.pow(2).mean - val epsilon = 1e-6f - x0 :/ (variance + epsilon).sqrt - - def apply(x: Tensor1[Embedding]): Tensor1[Embedding] = - val normalized = standardize(x) - normalized * params.weight + params.bias - - private case class TransformerLayer(params: HiddenParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: - - private val mlp = MLP(params.mlp) - private val multiHeadAttention = MultiHeadAttention(params.attn) - private val preNormalization = LayerNorm(params.ln1) - private val postNormalization = LayerNorm(params.ln2) +// def apply(X : Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = +// val heads = zipvmap(Axis[Heads])(params.WQ, params.WK, params.WV): (wqi, wki, wvi) => +// attention(wqi, wki, wvi)(X) +// heads.vmap(Axis[Context])(heads => projection(heads.ravel)) + +// private def attention( +// wq : Tensor2[Embedding, Query], +// wk : Tensor2[Embedding, Key], +// wv : Tensor2[Embedding, Value]) +// (x : Tensor2[Context, Embedding]): Tensor2[Context, Value] = +// trait AttnWeights derives Label +// val q = x.contract(Axis[Embedding])(wq) +// val k = x.contract(Axis[Embedding])(wk) +// val v = x.contract(Axis[Embedding])(wv) +// val dk = Tensor0(Math.sqrt(k.shape(Axis[Key])).toFloat) +// val attnWeights = (q.contract(Axis[Query ~ Key])(k) :/ dk) +// .vmap(Axis[Context])(x => softmax(x).relabelTo(Axis[AttnWeights])) +// val result = attnWeights.contract(Axis[AttnWeights ~ Context])(v) +// result + +// private case class LayerNorm(params: LNParams) extends Function[Tensor1[Embedding], Tensor1[Embedding]]: + +// private def standardize(x: Tensor1[Embedding]): Tensor1[Embedding] = +// val mean = x.mean +// val x0 = x :- mean +// val variance = x0.pow(2).mean +// val epsilon = 1e-6f +// x0 :/ (variance + epsilon).sqrt + +// def apply(x: Tensor1[Embedding]): Tensor1[Embedding] = +// val normalized = standardize(x) +// normalized * params.weight + params.bias + +// private case class TransformerLayer(params: HiddenParams) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: + +// private val mlp = MLP(params.mlp) +// private val multiHeadAttention = MultiHeadAttention(params.attn) +// private val preNormalization = LayerNorm(params.ln1) +// private val postNormalization = LayerNorm(params.ln2) - def apply(t: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = - val attnDelta = multiHeadAttention(t.vmap(Axis[Context])(preNormalization)) - val t2 = t + attnDelta - val mlpDelta = mlp(t2.vmap(Axis[Context])(postNormalization)) - t2 + mlpDelta - - private case class Transformer(layers: List[TransformerLayer]) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: - override 
def apply(t: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = - layers.foldLeft(t) { (acc, layer) => layer(acc) } - - private val transformer = Transformer(params.layers.map(layerParams => TransformerLayer(layerParams))) - private val finalNormalization = LayerNorm(params.ln_f) - private val outputLayer = LinearLayer(LinearParams( - weight = params.wte.transpose, - bias = Tensor.zeros(Shape(params.wte.shape.dim(Axis[Vocab]))), - )) - - // type Int32Tensor1[L <: String] = Tensor1[L] { type DType = DType.UInt32.type } - - private def embedder(tokens: Tensor1[Context]): Tensor2[Context, Embedding] = - tokens.vmap(Axis[Context])(token => - params.wte.slice(Axis[Vocab] -> token.toInt) - ) - - private def addPositionEncoding(embeddings: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = - embeddings + params.wpe - - def logits(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context, Vocab)] = - inputTokens.vmap(Axis[Batch])(tokens => - val startEmbeddings = addPositionEncoding(embedder(tokens)) - val endEmbeddings = transformer(startEmbeddings) - endEmbeddings.vmap(Axis[Context])(x => - val xNorm = finalNormalization(x) - outputLayer(xNorm) - ) - ) - - def probits(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context, Vocab)] = - logits(inputTokens).vapply(Axis[Vocab])(softmax) - - def apply(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context)] = - logits(inputTokens).argmax(Axis[Vocab]) +// def apply(t: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = +// val attnDelta = multiHeadAttention(t.vmap(Axis[Context])(preNormalization)) +// val t2 = t + attnDelta +// val mlpDelta = mlp(t2.vmap(Axis[Context])(postNormalization)) +// t2 + mlpDelta + +// private case class Transformer(layers: List[TransformerLayer]) extends Function[Tensor2[Context, Embedding], Tensor2[Context, Embedding]]: +// override def apply(t: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = +// layers.foldLeft(t) { (acc, layer) => layer(acc) } + +// private val transformer = Transformer(params.layers.map(layerParams => TransformerLayer(layerParams))) +// private val finalNormalization = LayerNorm(params.ln_f) +// private val outputLayer = LinearLayer(LinearParams( +// weight = params.wte.transpose, +// bias = Tensor.zeros(Shape(params.wte.shape.dim(Axis[Vocab]))), +// )) + +// // type Int32Tensor1[L <: String] = Tensor1[L] { type DType = DType.UInt32.type } + +// private def embedder(tokens: Tensor1[Context]): Tensor2[Context, Embedding] = +// tokens.vmap(Axis[Context])(token => +// params.wte.slice(Axis[Vocab] -> token.toInt) +// ) + +// private def addPositionEncoding(embeddings: Tensor2[Context, Embedding]): Tensor2[Context, Embedding] = +// embeddings + params.wpe + +// def logits(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context, Vocab)] = +// inputTokens.vmap(Axis[Batch])(tokens => +// val startEmbeddings = addPositionEncoding(embedder(tokens)) +// val endEmbeddings = transformer(startEmbeddings) +// endEmbeddings.vmap(Axis[Context])(x => +// val xNorm = finalNormalization(x) +// outputLayer(xNorm) +// ) +// ) + +// def probits(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context, Vocab)] = +// logits(inputTokens).vapply(Axis[Vocab])(softmax) + +// def apply(inputTokens: Tensor[(Batch, Context)]): Tensor[(Batch, Context)] = +// logits(inputTokens).argmax(Axis[Vocab]) -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -lazy val tiktoken = py.module("tiktoken") - -case class Tokenizer(enc: py.Dynamic): - def 
encode(s: String): List[Int] = - val pythonSet = py.Dynamic.global.set(Seq("<|endoftext|>").toPythonProxy) - enc.encode(s, allowed_special=pythonSet).as[List[Int]] - - def decode(l: List[Int]): String = - enc.decode(l.toPythonProxy).as[String] - -case class Inference(gpt2: GPT2, tokenizer: Tokenizer): - - def apply(input: String): LazyList[String] = - val tokenIds = tokenizer.encode(input) - def loop(currentTokens: List[Int]): LazyList[String] = - println(s"Current tokens: $currentTokens") - val inputTensor = Tensor( - Shape((Axis[Batch] -> 1, Axis[Context] -> currentTokens.length)), - currentTokens.map(_.toFloat).toArray, - // DType.Int32, - ) - val nextTokenTensor = gpt2(inputTensor) - val nextTokenId = nextTokenTensor.slice(Axis[Batch] -> 0).slice(Axis[Context] -> (currentTokens.length - 1)).toInt - val newTokens = currentTokens :+ nextTokenId - val decoded = tokenizer.decode(newTokens) - LazyList.cons(decoded, loop(newTokens)) - loop(tokenIds) +// import me.shadaj.scalapy.py +// import me.shadaj.scalapy.py.SeqConverters +// lazy val tiktoken = py.module("tiktoken") + +// case class Tokenizer(enc: py.Dynamic): +// def encode(s: String): List[Int] = +// val pythonSet = py.Dynamic.global.set(Seq("<|endoftext|>").toPythonProxy) +// enc.encode(s, allowed_special=pythonSet).as[List[Int]] + +// def decode(l: List[Int]): String = +// enc.decode(l.toPythonProxy).as[String] + +// case class Inference(gpt2: GPT2, tokenizer: Tokenizer): + +// def apply(input: String): LazyList[String] = +// val tokenIds = tokenizer.encode(input) +// def loop(currentTokens: List[Int]): LazyList[String] = +// println(s"Current tokens: $currentTokens") +// val inputTensor = Tensor( +// Shape((Axis[Batch] -> 1, Axis[Context] -> currentTokens.length)), +// currentTokens.map(_.toFloat).toArray, +// // DType.Int32, +// ) +// val nextTokenTensor = gpt2(inputTensor) +// val nextTokenId = nextTokenTensor.slice(Axis[Batch] -> 0).slice(Axis[Context] -> (currentTokens.length - 1)).toInt +// val newTokens = currentTokens :+ nextTokenId +// val decoded = tokenizer.decode(newTokens) +// LazyList.cons(decoded, loop(newTokens)) +// loop(tokenIds) -object GPT2Inference: - - import java.io.RandomAccessFile - import java.nio.channels.FileChannel - import java.nio.{ByteBuffer, ByteOrder} - import java.nio.charset.StandardCharsets - import shapeful.jax.Jax - import shapeful.tensor.DType - import me.shadaj.scalapy.py - import me.shadaj.scalapy.py.SeqConverters - - case class TensorInfo(dtype: String, shape: List[Int], start: Long, end: Long) - - object SafeTensorsReader: - import me.shadaj.scalapy.py.SeqConverters - import java.util.Base64 - - // A compact Python loader that decodes Base64 back to a tensor - // Defined as a single line to completely avoid IndentationErrors - private val pythonLoader = py.eval("""lambda b64, dtype, shape: (__import__('numpy').frombuffer(__import__('base64').b64decode(b64), dtype={'F32':__import__('numpy').float32,'I32':__import__('numpy').int32,'I64':__import__('numpy').int64}[dtype]).reshape(shape))""") - - def readHeader(filePath: String): (Map[String, TensorInfo], Long) = - // ... (Keep your existing header parsing code exactly as it is) ... 
- // (I omitted it here for brevity, but copy-paste your previous working readHeader) - val file = new RandomAccessFile(filePath, "r") - val channel = file.getChannel - try - val headerSizeBuffer = ByteBuffer.allocate(8) - headerSizeBuffer.order(ByteOrder.LITTLE_ENDIAN) - channel.read(headerSizeBuffer) - headerSizeBuffer.flip() - val headerSize = headerSizeBuffer.getLong - - val jsonBuffer = ByteBuffer.allocate(headerSize.toInt) - channel.read(jsonBuffer) - jsonBuffer.flip() - val jsonString = new String(jsonBuffer.array(), StandardCharsets.UTF_8) - - val json = ujson.read(jsonString) - val meta = json.obj +// object GPT2Inference: + +// import java.io.RandomAccessFile +// import java.nio.channels.FileChannel +// import java.nio.{ByteBuffer, ByteOrder} +// import java.nio.charset.StandardCharsets +// import shapeful.jax.Jax +// import shapeful.tensor.DType +// import me.shadaj.scalapy.py +// import me.shadaj.scalapy.py.SeqConverters + +// case class TensorInfo(dtype: String, shape: List[Int], start: Long, end: Long) + +// object SafeTensorsReader: +// import me.shadaj.scalapy.py.SeqConverters +// import java.util.Base64 + +// // A compact Python loader that decodes Base64 back to a tensor +// // Defined as a single line to completely avoid IndentationErrors +// private val pythonLoader = py.eval("""lambda b64, dtype, shape: (__import__('numpy').frombuffer(__import__('base64').b64decode(b64), dtype={'F32':__import__('numpy').float32,'I32':__import__('numpy').int32,'I64':__import__('numpy').int64}[dtype]).reshape(shape))""") + +// def readHeader(filePath: String): (Map[String, TensorInfo], Long) = +// // ... (Keep your existing header parsing code exactly as it is) ... +// // (I omitted it here for brevity, but copy-paste your previous working readHeader) +// val file = new RandomAccessFile(filePath, "r") +// val channel = file.getChannel +// try +// val headerSizeBuffer = ByteBuffer.allocate(8) +// headerSizeBuffer.order(ByteOrder.LITTLE_ENDIAN) +// channel.read(headerSizeBuffer) +// headerSizeBuffer.flip() +// val headerSize = headerSizeBuffer.getLong + +// val jsonBuffer = ByteBuffer.allocate(headerSize.toInt) +// channel.read(jsonBuffer) +// jsonBuffer.flip() +// val jsonString = new String(jsonBuffer.array(), StandardCharsets.UTF_8) + +// val json = ujson.read(jsonString) +// val meta = json.obj - val tensorMap = meta.filterKeys(_ != "__metadata__").map { case (name, data) => - val offsets = data("data_offsets").arr.map(_.num.toLong) - val shape = data("shape").arr.map(_.num.toInt).toList - val dtype = data("dtype").str - name -> TensorInfo(dtype, shape, offsets(0), offsets(1)) - }.toMap +// val tensorMap = meta.filterKeys(_ != "__metadata__").map { case (name, data) => +// val offsets = data("data_offsets").arr.map(_.num.toLong) +// val shape = data("shape").arr.map(_.num.toInt).toList +// val dtype = data("dtype").str +// name -> TensorInfo(dtype, shape, offsets(0), offsets(1)) +// }.toMap - val dataStartPos = 8 + headerSize - (tensorMap, dataStartPos) - finally - file.close() - - def loadTensor(filePath: String, info: TensorInfo, dataStartPos: Long): Jax.PyDynamic = - // 1. Read bytes in JVM (Fast file IO) - val file = new RandomAccessFile(filePath, "r") - try - val len = (info.end - info.start).toInt - val bytes = new Array[Byte](len) +// val dataStartPos = 8 + headerSize +// (tensorMap, dataStartPos) +// finally +// file.close() + +// def loadTensor(filePath: String, info: TensorInfo, dataStartPos: Long): Jax.PyDynamic = +// // 1. 
Read bytes in JVM (Fast file IO) +// val file = new RandomAccessFile(filePath, "r") +// try +// val len = (info.end - info.start).toInt +// val bytes = new Array[Byte](len) - file.seek(dataStartPos + info.start) - file.readFully(bytes) // Reads entire chunk at once +// file.seek(dataStartPos + info.start) +// file.readFully(bytes) // Reads entire chunk at once - // 2. Encode to Base64 (Fast JVM native operation) - // This turns 500MB of bytes into one String, avoiding the "List of Ints" bottleneck - val b64String = Base64.getEncoder.encodeToString(bytes) +// // 2. Encode to Base64 (Fast JVM native operation) +// // This turns 500MB of bytes into one String, avoiding the "List of Ints" bottleneck +// val b64String = Base64.getEncoder.encodeToString(bytes) - // 3. Pass to Python - val result = pythonLoader(b64String, info.dtype, info.shape.toPythonProxy) +// // 3. Pass to Python +// val result = pythonLoader(b64String, info.dtype, info.shape.toPythonProxy) - Jax.jnp.array(result) - finally - file.close() - def main(args: Array[String]): Unit = - val filePath = "data/gpt.safetensors" +// Jax.jnp.array(result) +// finally +// file.close() +// def main(args: Array[String]): Unit = +// val filePath = "data/gpt.safetensors" - // Read header to get tensor info - val (tensorMap, dataStartPos) = SafeTensorsReader.readHeader(filePath) +// // Read header to get tensor info +// val (tensorMap, dataStartPos) = SafeTensorsReader.readHeader(filePath) - def load1[L](name: String, axis: Axis[L])(using Label[L]): Tensor1[L] = - val info = tensorMap(name) - val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) - Tensor.fromPy(jaxArray) - - def load2[L1, L2](name: String, axis1: Axis[L1], axis2: Axis[L2])(using Label[L1], Label[L2]): Tensor2[L1, L2] = - val info = tensorMap(name) - val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) - Tensor.fromPy(jaxArray) - - def loadLinear[In, Out](prefix: String, inAxis: Axis[In], outAxis: Axis[Out])(using Label[In], Label[Out]): LinearParams[In, Out] = - val w = load2(s"$prefix.weight", inAxis, outAxis) - val b = load1(s"$prefix.bias", outAxis) - LinearParams(w, b) - - def loadLN(prefix: String): LNParams = - val w = load1(s"$prefix.weight", Axis[Embedding]) - val b = load1(s"$prefix.bias", Axis[Embedding]) - LNParams(w, b) - - val wpe = load2("wpe.weight", Axis[Context], Axis[Embedding]) - println("Successfully loaded WPE parameters") - val wte = load2("wte.weight", Axis[Vocab], Axis[Embedding]) - println("Successfully loaded WTE parameters") - val ln_f = loadLN("ln_f") - println("Successfully loaded final LayerNorm parameters") - - val layers = (0 until 12).map { i => - val prefix = s"h.$i" - val ln1 = loadLN(s"$prefix.ln_1") - val ln2 = loadLN(s"$prefix.ln_2") +// def load1[L](name: String, axis: Axis[L])(using Label[L]): Tensor1[L] = +// val info = tensorMap(name) +// val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) +// Tensor.fromPy(jaxArray) + +// def load2[L1, L2](name: String, axis1: Axis[L1], axis2: Axis[L2])(using Label[L1], Label[L2]): Tensor2[L1, L2] = +// val info = tensorMap(name) +// val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) +// Tensor.fromPy(jaxArray) + +// def loadLinear[In, Out](prefix: String, inAxis: Axis[In], outAxis: Axis[Out])(using Label[In], Label[Out]): LinearParams[In, Out] = +// val w = load2(s"$prefix.weight", inAxis, outAxis) +// val b = load1(s"$prefix.bias", outAxis) +// LinearParams(w, b) + +// def loadLN(prefix: String): LNParams = +// val 
w = load1(s"$prefix.weight", Axis[Embedding]) +// val b = load1(s"$prefix.bias", Axis[Embedding]) +// LNParams(w, b) + +// val wpe = load2("wpe.weight", Axis[Context], Axis[Embedding]) +// println("Successfully loaded WPE parameters") +// val wte = load2("wte.weight", Axis[Vocab], Axis[Embedding]) +// println("Successfully loaded WTE parameters") +// val ln_f = loadLN("ln_f") +// println("Successfully loaded final LayerNorm parameters") + +// val layers = (0 until 12).map { i => +// val prefix = s"h.$i" +// val ln1 = loadLN(s"$prefix.ln_1") +// val ln2 = loadLN(s"$prefix.ln_2") - val cAttn = loadLinear(s"$prefix.attn.c_attn", Axis[Embedding], Axis[QKV]) - val cProj = loadLinear(s"$prefix.attn.c_proj", Axis[Heads |*| Value], Axis[Embedding]) - val attn = MultiHeadAttentionParams(cAttn, cProj, numHeads = 12) +// val cAttn = loadLinear(s"$prefix.attn.c_attn", Axis[Embedding], Axis[QKV]) +// val cProj = loadLinear(s"$prefix.attn.c_proj", Axis[Heads |*| Value], Axis[Embedding]) +// val attn = MultiHeadAttentionParams(cAttn, cProj, numHeads = 12) - val c_fc = loadLinear(s"$prefix.mlp.c_fc", Axis[Embedding], Axis[Inner]) - val c_proj = loadLinear(s"$prefix.mlp.c_proj", Axis[Inner], Axis[Embedding]) - val mlp = MLPParams(c_fc, c_proj) - println(s"Successfully loaded layer $i parameters") - - HiddenParams(ln1, attn, ln2, mlp) - }.toList - println("Successfully loaded all layers parameters") - - val params = GPT2Params(wpe, wte, layers, ln_f) - val gpt2 = GPT2(params) - val inference = Inference(gpt2, Tokenizer(tiktoken.get_encoding("gpt2"))) - // val stream = inference("Hello, my name is") - // stream.foreach(println) +// val c_fc = loadLinear(s"$prefix.mlp.c_fc", Axis[Embedding], Axis[Inner]) +// val c_proj = loadLinear(s"$prefix.mlp.c_proj", Axis[Inner], Axis[Embedding]) +// val mlp = MLPParams(c_fc, c_proj) +// println(s"Successfully loaded layer $i parameters") + +// HiddenParams(ln1, attn, ln2, mlp) +// }.toList +// println("Successfully loaded all layers parameters") + +// val params = GPT2Params(wpe, wte, layers, ln_f) +// val gpt2 = GPT2(params) +// val inference = Inference(gpt2, Tokenizer(tiktoken.get_encoding("gpt2"))) +// // val stream = inference("Hello, my name is") +// // stream.foreach(println) From eaddf2a0b0a21f5423971c2fb0ed14a6451f2855 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Dec 2025 17:18:53 +0100 Subject: [PATCH 10/10] fix setting device on tensor creation --- core/src/main/scala/shapeful/tensor/Tensor.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/shapeful/tensor/Tensor.scala b/core/src/main/scala/shapeful/tensor/Tensor.scala index c9547c4..537e0d0 100644 --- a/core/src/main/scala/shapeful/tensor/Tensor.scala +++ b/core/src/main/scala/shapeful/tensor/Tensor.scala @@ -70,10 +70,10 @@ object Tensor: .array( values.toPythonProxy, dtype = dtype.jaxType, - device = device.jaxDevice, ) .reshape(shape.dimensions.toPythonProxy) - Tensor[T, V](jaxValues) + val jaxValuesOnDevice = Jax.device_put(jaxValues, device.jaxDevice) + Tensor[T, V](jaxValuesOnDevice) def zeros[T <: Tuple : Labels](shape: Shape[T]): Tensor[T, V] = Tensor[T, V](Jax.jnp.zeros(shape.dimensions.toPythonProxy, dtype = dtype.jaxType))