One interesting observation is that Mooncake's forward mode makes it easy to trace Julia code into an intermediate representation. By restricting programs to a subset of Julia (for example, disallowing input-value-dependent control flow and mutation), we can enable transformations similar to those in JAX. The lowest-hanging fruit would be to implement vmap and then regenerate Julia code from the transformed IR.
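As a toy illustration of what such a trace-then-transform pipeline could look like, here is a self-contained sketch: a tiny tape IR, tracing via dispatch on a `Tracer` type, and a `vmap` that replays each recorded primitive broadcast over a batch axis. None of this is Mooncake's actual API; it only mimics the shape of the idea.

```julia
# Toy tape IR: each traced primitive becomes one Instr.
struct Tracer
    id::Int
end

struct Instr
    op::Function
    args::Vector{Any}   # Tracers or literal constants
    out::Tracer
end

mutable struct Tape
    instrs::Vector{Instr}
    nextid::Int
end
Tape() = Tape(Instr[], 1)

fresh!(t::Tape) = (tr = Tracer(t.nextid); t.nextid += 1; tr)

function record!(t::Tape, op, args...)
    out = fresh!(t)
    push!(t.instrs, Instr(op, collect(Any, args), out))
    return out
end

# Intercept a few primitives on Tracer values (a real tracer would hook
# into the lowered IR instead of relying on dispatch).
const TAPE = Ref{Tape}()
Base.:+(a::Tracer, b::Tracer) = record!(TAPE[], +, a, b)
Base.:*(a::Tracer, b::Tracer) = record!(TAPE[], *, a, b)
Base.:*(a::Number, b::Tracer) = record!(TAPE[], *, a, b)

function trace(f, nargs::Int)
    TAPE[] = Tape()
    ins = [fresh!(TAPE[]) for _ in 1:nargs]
    return TAPE[], ins, f(ins...)
end

# vmap: replay the tape with every primitive broadcast over the batch axis.
function vmap(f, nargs::Int)
    tape, ins, out = trace(f, nargs)
    return function (batched...)
        env = Dict{Int,Any}(tr.id => x for (tr, x) in zip(ins, batched))
        for instr in tape.instrs
            vals = (a isa Tracer ? env[a.id] : a for a in instr.args)
            env[instr.out.id] = instr.op.(vals...)   # broadcast the primitive
        end
        return env[out.id]
    end
end

f(x, y) = 2 * x + x * y
batched_f = vmap(f, 2)
batched_f([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])   # [6.0, 14.0, 24.0]
```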
In addition, this approach allows us to linearize control flow and then leverage DiffMatic to differentiate vector, matrix, and tensor expressions.
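To make the linearization payoff concrete: once control flow is gone, an expression like f(x) = xᵀAx is plain tensor algebra, and a symbolic differentiator can apply rules such as d(xᵀAx)/dx = (A + Aᵀ)x. The snippet below checks that rule numerically; DiffMatic's actual interface is not used or assumed here.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
f(x) = dot(x, A * x)        # xᵀAx
∇f(x) = (A + A') * x        # the symbolic derivative rule

x = [0.5, -1.0]
ε = 1e-6
# central finite differences agree with the symbolic gradient
fd = map(1:2) do i
    e = zeros(2); e[i] = 1.0
    (f(x + ε * e) - f(x - ε * e)) / (2ε)
end
@assert isapprox(fd, ∇f(x); atol = 1e-6)
```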
EDIT 1: In parallel, it is compelling to explore generating Triton code (i.e., via Triton's Python DSL) for linear algebra operations, similar in spirit to PyTorch 2.x, where TorchInductor emits Triton kernels for fused, GPU-friendly workloads while falling back to ATen and vendor libraries like cuBLAS and cuDNN for more complex cases.
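A hedged sketch of what emitting Triton source from a straight-line elementwise IR could look like: the `ElemOp` IR below is hypothetical, but the emitted Python text sticks to Triton's public `tl.load` / `tl.store` / `tl.program_id` idioms.

```julia
# Hypothetical straight-line elementwise IR; the generator prints a Triton
# kernel (Python source) that fuses all ops into a single pass.
struct ElemOp
    out::Symbol
    op::String        # "+", "*", ...
    lhs::Symbol
    rhs::Symbol
end

function emit_triton(name, inputs::Vector{Symbol}, body::Vector{ElemOp}, out::Symbol)
    io = IOBuffer()
    args = join(string.(inputs) .* "_ptr", ", ")
    println(io, "import triton")
    println(io, "import triton.language as tl")
    println(io, "@triton.jit")
    println(io, "def $name($args, out_ptr, n, BLOCK: tl.constexpr):")
    println(io, "    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)")
    println(io, "    mask = offs < n")
    for x in inputs
        println(io, "    $x = tl.load($(x)_ptr + offs, mask=mask)")
    end
    for instr in body
        println(io, "    $(instr.out) = $(instr.lhs) $(instr.op) $(instr.rhs)")
    end
    println(io, "    tl.store(out_ptr + offs, $out, mask=mask)")
    return String(take!(io))
end

# f(x, y) = x * y + x, linearized into two fused elementwise ops:
print(emit_triton(:fused_f, [:x, :y],
                  [ElemOp(:t1, "*", :x, :y), ElemOp(:t2, "+", :t1, :x)], :t2))
```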
EDIT 2: The same pipeline could also generate cuTile.jl / CUDA.jl calls or kernels.
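For the CUDA.jl target, the same fused op maps onto a short kernel using CUDA.jl's documented kernel API (this requires a CUDA-capable GPU; no cuTile.jl interface is assumed here).

```julia
# One possible CUDA.jl rendering of the fused op f(x, y) = x * y + x.
using CUDA

function fused_f_kernel!(out, x, y)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(out)
        @inbounds out[i] = x[i] * y[i] + x[i]
    end
    return nothing
end

x, y = CUDA.rand(1024), CUDA.rand(1024)
out = similar(x)
@cuda threads=256 blocks=cld(length(x), 256) fused_f_kernel!(out, x, y)
```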
EDIT 3: If we restrict tracing to primitive numeric types (e.g., floating-point scalars) and StaticArrays, this approach becomes quite tractable. Importantly, concrete input types are required only during tracing; the regenerated code itself can remain generic and operate over GPU arrays or other array abstractions.
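A small illustration of that separation: suppose tracing ran on concrete SVector inputs. The regenerated code (hand-written here to stand in for the generator's output) only uses generic broadcasting, so it is not tied to those concrete types.

```julia
using StaticArrays

# Stand-in for code regenerated from a trace on SVector{2,Float64} inputs:
# written against generic broadcasting, not the concrete tracing types.
regenerated_f(x) = 2 .* x .+ x .* x

regenerated_f(SVector(1.0, 2.0))   # static array in, static array out
regenerated_f([1.0, 2.0, 3.0])     # plain Vector works too
# regenerated_f(cu_x)              # ...as would a GPU array
```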