Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.
14 changes: 7 additions & 7 deletions core/src/main/scala/shapeful/autodiff/Autodiff.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package shapeful.autodiff

import shapeful.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape, AxisIndices}
import shapeful.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape, AxisIndices, Value}
import shapeful.jax.Jax
import me.shadaj.scalapy.py

Expand All @@ -9,17 +9,17 @@ object Autodiff:
type Gradient[In, Out] = Out match
case EmptyTuple => EmptyTuple
case h *: t => Gradient[In, h] *: Gradient[In, t]
case Tensor[outS] => GradientTensorVsInput[In, outS]
case Tensor[outS, v] => GradientTensorVsInput[In, outS, v]
case _ => EmptyTuple

type GradientTensorVsInput[In, OutShape <: Tuple] = In match
type GradientTensorVsInput[In, OutShape <: Tuple, V] = In match
case EmptyTuple => EmptyTuple
case h *: t => GradientTensorVsInput[h, OutShape] *: GradientTensorVsInput[t, OutShape]
case Tensor[inS] => Tensor[Tuple.Concat[OutShape, inS]]
case h *: t => GradientTensorVsInput[h, OutShape, V] *: GradientTensorVsInput[t, OutShape, V]
case Tensor[inS, v2] => Tensor[Tuple.Concat[OutShape, inS], V]

def grad[Input](f: Input => Tensor0)(using
def grad[Input, V : Value](f: Input => Tensor0[V])(using
inTree: ToPyTree[Input],
outTree: ToPyTree[Tensor0],
outTree: ToPyTree[Tensor0[V]],
): Input => Input =

val fpy = (jxpr: py.Dynamic) =>
Expand Down
16 changes: 9 additions & 7 deletions core/src/main/scala/shapeful/autodiff/PyTree.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package shapeful.autodiff

import shapeful.tensor.{Tensor, Shape}
import shapeful.tensor.{Tensor, Shape, Value}
import shapeful.jax.Jax

import me.shadaj.scalapy.py
Expand All @@ -18,9 +18,9 @@ object ToPyTree:
def apply[P](using pt: ToPyTree[P]): ToPyTree[P] = pt

// Keep the tensor instance
given [T <: Tuple : Labels]: ToPyTree[Tensor[T]] with
def toPyTree(t: Tensor[T]): Jax.PyAny = t.jaxValue
def fromPyTree(p: Jax.PyAny): Tensor[T] = Tensor.fromPy(p.as[Jax.PyDynamic])
given [T <: Tuple : Labels, V : Value]: ToPyTree[Tensor[T, V]] with
def toPyTree(t: Tensor[T, V]): Jax.PyAny = t.jaxValue
def fromPyTree(p: Jax.PyAny): Tensor[T, V] = Tensor.fromPy(p.as[Jax.PyDynamic])

// Tuple instances - these should have lower priority than specific case classes
given tupleInstance[A, B](using ta: ToPyTree[A], tb: ToPyTree[B]): ToPyTree[(A, B)] with
Expand Down Expand Up @@ -57,7 +57,9 @@ object ToPyTree:

inline def reconstructField[T](pyElem: py.Dynamic): T =
inline erasedValue[T] match
case _: Tensor[?] => Tensor.fromPy(pyElem.as[Jax.PyDynamic]).asInstanceOf[T]
case _: Tensor[?, ?] =>
// For tensors, delegate to the ToPyTree instance which has the proper type info
compiletime.summonInline[ToPyTree[T]].fromPyTree(pyElem)
case _: String =>
pyElem.as[String].asInstanceOf[T]
case _: Int =>
Expand All @@ -84,8 +86,8 @@ object ToPyTree:

inline def convertSingleField[T](elem: T): Jax.PyAny =
inline erasedValue[T] match
case _: Tensor[?] =>
elem.asInstanceOf[Tensor[?]].jaxValue
case _: Tensor[?, ?] =>
elem.asInstanceOf[Tensor[?, ?]].jaxValue
case _: String =>
py.Dynamic.global.str(elem.asInstanceOf[String])
case _: Int =>
Expand Down
22 changes: 11 additions & 11 deletions core/src/main/scala/shapeful/autodiff/TensorTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ import scala.compiletime.*
// TODO hot fix with retag and context parameter... maybe this can be improved?

trait TensorTree[P]:
def map(p: P, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): P
def zipMap(p1: P, p2: P, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): P
def map(p: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): P
def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P

object TensorTree extends TensorTreeLowPriority:
def apply[P](using pt: TensorTree[P]): TensorTree[P] = pt

given [Q <: Tuple](using n: Labels[Q]): TensorTree[Tensor[Q]] with
def map(t: Tensor[Q], f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): Tensor[Q] =
given [Q <: Tuple, V](using n: Labels[Q], v: Value[V]): TensorTree[Tensor[Q, V]] with
def map(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T], Value[V2]) ?=> (Tensor[T, V2] => Tensor[T, V2])): Tensor[Q, V] =
import TensorOps.retag
f[Q](using n)(t.retag[Q](using n))
f[Q, V](using n, v)(t.retag[Q](using n))

def zipMap(p1: Tensor[Q], p2: Tensor[Q], f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): Tensor[Q] =
def zipMap(p1: Tensor[Q, V], p2: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T], Value[V2]) ?=> ((Tensor[T, V2], Tensor[T, V2]) => Tensor[T, V2])): Tensor[Q, V] =
import TensorOps.retag
f[Q](using n)(p1.retag[Q](using n), p2.retag[Q](using n))
f[Q, V](using n, v)(p1.retag[Q](using n), p2.retag[Q](using n))

inline given derived[P <: Product](using m: Mirror.ProductOf[P]): TensorTree[P] =
val elemInstances = summonAll[Tuple.Map[m.MirroredElemTypes, TensorTree]]
Expand All @@ -31,13 +31,13 @@ object TensorTree extends TensorTreeLowPriority:
instances: List[TensorTree[Any]],
m: Mirror.ProductOf[P]
): TensorTree[P] = new TensorTree[P]:
def map(p: P, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): P =
def map(p: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): P =
val inputs = p.productIterator.toList
val mappedElems = inputs.zip(instances).map:
case (elem, inst) => inst.map(elem, f)
m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray))

def zipMap(p1: P, p2: P, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): P =
def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P =
val inputs1 = p1.productIterator.toList
val inputs2 = p2.productIterator.toList
val mappedElems = inputs1.zip(inputs2).zip(instances).map:
Expand All @@ -46,5 +46,5 @@ object TensorTree extends TensorTreeLowPriority:

trait TensorTreeLowPriority:
given identity[A]: TensorTree[A] = new TensorTree[A]:
def map(p: A, f: [T <: Tuple] => Labels[T] ?=> Tensor[T] => Tensor[T]): A = p
def zipMap(p1: A, p2: A, f: [T <: Tuple] => Labels[T] ?=> (Tensor[T], Tensor[T]) => Tensor[T]): A = p1
def map(p: A, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> (Tensor[T, V] => Tensor[T, V])): A = p
def zipMap(p1: A, p2: A, f: [T <: Tuple, V] => (Labels[T], Value[V]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): A = p1
10 changes: 5 additions & 5 deletions core/src/main/scala/shapeful/jax/Jit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import shapeful.tensor.{Tensor, Shape, Labels}
import shapeful.jax.{Jax, JaxDType}
import shapeful.autodiff.ToPyTree
import me.shadaj.scalapy.py

import shapeful.tensor.Value

object Jit:

def jit[PyTree: ToPyTree, OutT <: Tuple : Labels](
f: PyTree => Tensor[OutT]
): PyTree => Tensor[OutT] =
def jit[PyTree: ToPyTree, OutT <: Tuple : Labels, V : Value](
f: PyTree => Tensor[OutT, V]
): PyTree => Tensor[OutT, V] =

// Python function that accepts a pytree
val fpy = (pyTreePy: Jax.PyDynamic) =>
Expand All @@ -25,7 +25,7 @@ object Jit:
(pyTree: PyTree) =>
val pyTreePy = ToPyTree[PyTree].toPyTree(pyTree)
val resultJax = jitted(pyTreePy)
Tensor.fromPy[OutT](resultJax)
Tensor.fromPy[OutT, V](resultJax)

def jit2[PyTree: ToPyTree, OutT <: Tuple : Labels](
f: PyTree => PyTree
Expand Down
52 changes: 46 additions & 6 deletions core/src/main/scala/shapeful/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ package object shapeful:
case Prime[l] *: tail => l *: RemovePrimes[tail]
case h *: tail => h *: RemovePrimes[tail]

extension[T <: Tuple : Labels](tensor: Tensor[T])
def dropPrimes: Tensor[RemovePrimes[T]] =
extension[T <: Tuple : Labels, V : Value](tensor: Tensor[T, V])
def dropPrimes: Tensor[RemovePrimes[T], V] =
given newLabels: Labels[RemovePrimes[T]] with
val names: List[String] =
val oldLabels = summon[Labels[T]]
oldLabels.names.toList.map(_.replace("'", ""))
Tensor.fromPy(tensor.jaxValue)
Tensor.fromPy[RemovePrimes[T], V](tensor.jaxValue)

@targetName("On")
infix trait ~[A, B]
Expand All @@ -42,6 +42,48 @@ package object shapeful:
export shapeful.tensor.{Shape, Shape0, Shape1, Shape2, Shape3}
export shapeful.tensor.{DType, Device}
export shapeful.tensor.{Label, Labels, Axis, AxisIndex, AxisIndices, Dim}
export shapeful.tensor.Value

// Opaque types for DTypes - clean imports and display without .type suffix
opaque type Float32 = DType.Float32.type
object Float32:
given Value[Float32] = summon[Value[DType.Float32.type]]

opaque type Float64 = DType.Float64.type
object Float64:
given Value[Float64] = summon[Value[DType.Float64.type]]

opaque type Int32 = DType.Int32.type
object Int32:
given Value[Int32] = summon[Value[DType.Int32.type]]

opaque type Int64 = DType.Int64.type
object Int64:
given Value[Int64] = summon[Value[DType.Int64.type]]

opaque type Int16 = DType.Int16.type
object Int16:
given Value[Int16] = summon[Value[DType.Int16.type]]

opaque type Int8 = DType.Int8.type
object Int8:
given Value[Int8] = summon[Value[DType.Int8.type]]

opaque type UInt32 = DType.UInt32.type
object UInt32:
given Value[UInt32] = summon[Value[DType.UInt32.type]]

opaque type UInt16 = DType.UInt16.type
object UInt16:
given Value[UInt16] = summon[Value[DType.UInt16.type]]

opaque type UInt8 = DType.UInt8.type
object UInt8:
given Value[UInt8] = summon[Value[DType.UInt8.type]]

opaque type Bool = DType.Bool.type
object Bool:
given Value[Bool] = summon[Value[DType.Bool.type]]

// Export type helpers
export shapeful.tensor.Axis.UnwrapAxes
Expand All @@ -58,7 +100,5 @@ package object shapeful:
// Export Just-in-Time compilation
export shapeful.jax.Jit.{jit, jit2}

// Export implicit conversions
object Conversions:
export shapeful.tensor.Tensor0.{given_Conversion_Int_Tensor0, given_Conversion_Float_Tensor0}

export shapeful.tensor.Tensor0.{given_Conversion_Int_Tensor0, given_Conversion_Float_Tensor0}
33 changes: 17 additions & 16 deletions core/src/main/scala/shapeful/random/Random.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import shapeful.tensor.*
import shapeful.tensor.TensorOps.*
import shapeful.jax.{Jax, JaxDType}
import me.shadaj.scalapy.py.SeqConverters
import shapeful.Float32

/** JAX-based random number generation with proper key management.
*
Expand Down Expand Up @@ -49,42 +50,42 @@ object Random:
*/

/** Normal distribution with specified mean and standard deviation */
def normal[T <: Tuple : Labels](
def normal[T <: Tuple : Labels, V : Value](
key: Key,
shape: Shape[T],
mean: Tensor0 = Tensor0(0f),
std: Tensor0 = Tensor0(1f),
dtype: DType = DType.Float32
): Tensor[T] =
mean: Tensor0[V] = Tensor0.of[Float32].apply(0f),
std: Tensor0[V] = Tensor0.of[Float32].apply(1f)
): Tensor[T, V] =
val dtype = summon[Value[V]].dtype
val jaxValues = Jax.jrandom.normal(
key.jaxKey,
shape.dimensions.toPythonProxy,
dtype = JaxDType.jaxDtype(dtype)
)
val standardNormal = Tensor.fromPy[T](jaxValues)
val standardNormal = Tensor.fromPy[T, V](jaxValues)
standardNormal :* std :+ mean

/** Uniform distribution in [0, 1) */
def uniform[T <: Tuple : Labels](
def uniform[T <: Tuple : Labels, V : Value](
key: Key,
shape: Shape[T],
dtype: DType = DType.Float32
): Tensor[T] = uniform(key, shape, Tensor0(0f), Tensor0(1f), dtype)
shape: Shape[T]
)(using ev: V =:= DType.Float32.type): Tensor[T, V] =
uniform(key, shape, Tensor0.of[V].apply(0f), Tensor0.of[V].apply(1f))

/** Uniform distribution in [minval, maxval) */
def uniform[T <: Tuple : Labels](
def uniform[T <: Tuple : Labels, V : Value](
key: Key,
shape: Shape[T],
minval: Tensor0,
maxval: Tensor0,
dtype: DType
): Tensor[T] =
minval: Tensor0[V],
maxval: Tensor0[V]
): Tensor[T, V] =
val dtype = summon[Value[V]].dtype
val jaxValues = Jax.jrandom.uniform(
key.jaxKey,
shape.dimensions.toPythonProxy,
minval = minval.jaxValue,
maxval = maxval.jaxValue,
dtype = JaxDType.jaxDtype(dtype)
)
Tensor.fromPy[T](jaxValues)
Tensor.fromPy[T, V](jaxValues)

Loading