IR Tracing and vmap for Julia #903

@yebai

Description

One interesting observation is that Mooncake’s forward mode makes it easy to trace Julia code into an intermediate representation. By restricting a subset of Julia syntax—such as disallowing input-value–dependent control flow and mutation—we can enable transformations similar to those in JAX. The lowest-hanging fruit would be to implement vmap and then regenerate Julia code from the transformed IR.
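To make the trace-then-transform idea concrete, here is a minimal, dependency-free Python sketch (none of these names — `Tracer`, `Op`, `trace`, `vmap` — are Mooncake's API; they are hypothetical). It traces straight-line numeric code into a tiny tape IR via operator overloading, then "vmaps" it by re-interpreting the IR per batch element. For scalar programs vmap degenerates to an elementwise map; for array programs each IR op would instead be rewritten to carry an extra batch dimension.

```python
class Op:
    """One primitive in the traced IR: an opcode plus operand references."""
    def __init__(self, fn, args):
        self.fn, self.args = fn, args

class Tracer:
    """Abstract value recorded onto a shared tape during tracing."""
    def __init__(self, tape, node):
        self.tape, self.node = tape, node
    def _emit(self, fn, a, b):
        b = b.node if isinstance(b, Tracer) else ("const", b)
        self.tape.append(Op(fn, (a, b)))
        return Tracer(self.tape, ("op", len(self.tape) - 1))
    def __add__(self, o): return self._emit("add", self.node, o)
    def __mul__(self, o): return self._emit("mul", self.node, o)
    __radd__, __rmul__ = __add__, __mul__

def trace(f):
    """Run f once on an abstract input, returning the tape and output node."""
    tape = []
    out = f(Tracer(tape, ("arg", 0)))
    return tape, out.node

def interpret(tape, out, x):
    """Re-execute the traced IR on a concrete input x."""
    vals = []
    def ev(n):
        kind, v = n
        return x if kind == "arg" else v if kind == "const" else vals[v]
    for op in tape:
        a, b = ev(op.args[0]), ev(op.args[1])
        vals.append(a + b if op.fn == "add" else a * b)
    return ev(out)

def vmap(f):
    tape, out = trace(f)                 # trace once, transform/reuse many times
    return lambda xs: [interpret(tape, out, x) for x in xs]

f = lambda x: x * x + 1.0
print(vmap(f)([0.0, 1.0, 2.0]))          # [1.0, 2.0, 5.0]
```

Note that the sketch only works because the traced subset excludes input-value-dependent control flow and mutation: the single recorded tape is then valid for every batch element.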

In addition, this approach allows us to linearize control flow and then leverage DiffMatic to differentiate vector, matrix, and tensor expressions.

EDIT 1: In parallel, it is compelling to explore generating Triton code (i.e., via Triton's Python DSL) for linear algebra operations, similar in spirit to PyTorch 2.x, where TorchInductor emits Triton kernels for fused, GPU-friendly workloads while falling back to ATen and vendor libraries like cuBLAS and cuDNN for more complex cases.

EDIT 2: It can also generate cuTile.jl / CUDA.jl calls or kernels.

EDIT 3: If we restrict tracing to primitive numeric types (e.g., floating-point scalars) and StaticArrays, this approach becomes quite tractable. Importantly, concrete input types are required only during tracing; the regenerated code itself can remain generic and operate over GPU arrays or other array abstractions.
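The "concrete types only during tracing" point can be illustrated with another small, hypothetical Python sketch: we trace with a symbolic stand-in for a concrete scalar, then regenerate source code. Because the emitted code uses only generic `+`/`*`, it runs unchanged on other numeric types (here, complex numbers stand in for GPU arrays or other array abstractions).

```python
class Sym:
    """Symbolic scalar: records the expression built during tracing."""
    def __init__(self, expr):
        self.expr = expr
    def _bin(self, op, other):
        rhs = other.expr if isinstance(other, Sym) else repr(other)
        return Sym(f"({self.expr} {op} {rhs})")
    def __add__(self, o): return self._bin("+", o)
    def __mul__(self, o): return self._bin("*", o)
    __radd__, __rmul__ = __add__, __mul__

def regenerate(f):
    """Trace f on a symbolic input and regenerate generic source code."""
    body = f(Sym("x")).expr          # tracing sees one concrete input shape...
    src = f"lambda x: {body}"        # ...but the regenerated code stays generic
    return eval(src), src

g, src = regenerate(lambda x: 2.0 * x * x + 1.0)
print(src)                           # lambda x: (((x * 2.0) * x) + 1.0)
print(g(3.0))                        # 19.0   (Float64-like input)
print(g(1j))                         # (-1+0j): same code, different numeric type
```

The same separation is what would let Julia code regenerated from the traced IR remain generic over `AbstractArray` subtypes even though tracing itself was restricted to floating-point scalars and StaticArrays.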
