test: verify iterative AD works with eval() mid-loop #654

@shinaoka

Description

Caution

AI auto-implementation prohibited — This issue is for verification and documentation only. Do not auto-generate implementation PRs.

Background

With the DimExpr migration (#651, #653), the graph layer now supports shape-agnostic execution. A key use case is iterative tensor network algorithms (DMRG, ALS) where:

  1. A computation graph is built incrementally across iterations
  2. eval() is called mid-loop for convergence checks
  3. After convergence, the entire loop is differentiated

Key insight

eval() is non-destructive — it compiles and executes the Fragment but does not consume or modify it. The TracedTensor remains connected to its graph after evaluation. This means:

  • Convergence checks via eval() do not break the graph
  • The user can continue building the graph after eval()
  • grad() can differentiate through the entire iterative computation
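The non-destructive behavior described above can be illustrated with a self-contained toy model. The `Expr` type below is purely illustrative (it is not the real TracedTensor/Fragment API): the point is only that when `eval()` borrows the graph (`&self`) instead of consuming it, the evaluated node can keep participating in graph construction afterwards.

```rust
// Toy expression graph (hypothetical; stands in for TracedTensor).
enum Expr {
    Const(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    // Non-destructive evaluation: borrows the graph, returns a value.
    fn eval(&self) -> f64 {
        match self {
            Expr::Const(v) => *v,
            Expr::Add(a, b) => a.eval() + b.eval(),
            Expr::Mul(a, b) => a.eval() * b.eval(),
        }
    }
}

fn main() {
    let x = Expr::Const(2.0);
    let y = Expr::Add(Box::new(x), Box::new(Expr::Const(3.0)));

    // "Convergence check": evaluate mid-construction.
    let val = y.eval();
    assert_eq!(val, 5.0);

    // y is still a graph node; keep extending the graph after eval().
    let z = Expr::Mul(Box::new(y), Box::new(Expr::Const(4.0)));
    assert_eq!(z.eval(), 20.0);
}
```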

Pseudo-code to verify

```rust
let params = TracedTensor::from_tensor(/* ... */);
let mut x = TracedTensor::from_tensor(/* initial value */);

for _ in 0..max_iter {
    // Build graph (each iteration extends the graph)
    let y = f(&x, &params);

    // Convergence check — eval() does NOT break the graph
    let val = y.eval(&mut engine);
    if converged(&val) {
        x = y;  // keep the TracedTensor, not the concrete value
        break;
    }

    // Feed the TracedTensor (not a concrete Tensor) to the next iteration.
    // This keeps the graph connected.
    x = g(&y);
}

// Differentiate through the ENTIRE loop
let loss = compute_loss(&x);
let grad = loss.grad(&params);
let grad_val = grad.eval(&mut engine);
// grad_val contains gradients accumulated through all iterations
```

What to verify

  1. Graph continuity: After y.eval(), y is still usable as a TracedTensor input to subsequent operations
  2. Gradient correctness: loss.grad(&params) produces correct gradients through multiple iterations (compare with finite differences)
  3. Graph deduplication: When both primal and gradient are evaluated, materialize_merge() deduplicates shared subexpressions via GlobalOpKey
  4. Memory growth: Graph size grows linearly with iteration count (O(K × |f|) for K iterations)
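Item 2 can be prototyped even without the real library. The sketch below is hypothetical (the `Dual`, `step`, and `run` names are not part of this codebase): plain forward-mode dual numbers stand in for the AD layer, the unrolled loop is a Newton iteration for sqrt(p), and the AD gradient of the final iterate is compared against a central finite difference and the analytic value 1/(2·sqrt(p)).

```rust
// Forward-mode dual number: value plus derivative w.r.t. the parameter p.
#[derive(Clone, Copy)]
struct Dual { v: f64, d: f64 }

impl Dual {
    fn add(self, o: Dual) -> Dual { Dual { v: self.v + o.v, d: self.d + o.d } }
    fn div(self, o: Dual) -> Dual {
        Dual { v: self.v / o.v, d: (self.d * o.v - self.v * o.d) / (o.v * o.v) }
    }
    fn scale(self, c: f64) -> Dual { Dual { v: c * self.v, d: c * self.d } }
}

// One iteration step: x <- 0.5 * (x + p / x)  (Newton's method for sqrt(p)).
fn step(x: Dual, p: Dual) -> Dual { x.add(p.div(x)).scale(0.5) }

// Unroll the loop; the derivative flows through every iteration.
fn run(p: Dual, iters: usize) -> Dual {
    let mut x = Dual { v: 1.0, d: 0.0 }; // initial value, independent of p
    for _ in 0..iters { x = step(x, p); }
    x
}

fn main() {
    let p0 = 2.0;
    let iters = 10;

    // AD through the whole unrolled loop (seed d = 1 on p).
    let ad = run(Dual { v: p0, d: 1.0 }, iters).d;

    // Central finite difference for comparison.
    let h = 1e-6;
    let fd = (run(Dual { v: p0 + h, d: 0.0 }, iters).v
            - run(Dual { v: p0 - h, d: 0.0 }, iters).v) / (2.0 * h);

    // The iteration converges to sqrt(p), so d x / d p -> 1 / (2 sqrt(p)).
    let analytic = 1.0 / (2.0 * p0.sqrt());
    assert!((ad - fd).abs() < 1e-5);
    assert!((ad - analytic).abs() < 1e-6);
}
```

The same three-way comparison (library AD vs. finite differences vs. a known analytic gradient) is what the verification test for item 2 should perform with real TracedTensor graphs.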

Notes

  • Memory grows with each iteration (all intermediate Fragments are retained). This is the standard unrolled-through-time tradeoff.
  • For memory-constrained use cases, gradient checkpointing would be a future extension.
  • Linalg AD rules currently require concrete DimExpr::Const values (const_dim panics on symbolic values). This is a known limitation documented in #653 (feat: shape-agnostic graph with DimExpr, follow-up to #651).
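For the checkpointing extension mentioned above, a minimal sketch of the idea for a scalar unrolled loop x_{k+1} = f(x_k) (all names here are hypothetical, not this library's API): store only every c-th state during the forward pass, then recompute each segment during the backward pass. Memory drops from O(K) to O(K/c + c) at the cost of roughly doubling forward compute.

```rust
// Illustrative step function and its derivative.
fn f(x: f64) -> f64 { x.sin() }
fn df(x: f64) -> f64 { x.cos() }

// Gradient of x_K w.r.t. x_0 via checkpointing: keep states x_0, x_c,
// x_2c, ... and recompute intermediate states segment by segment.
fn grad_checkpointed(x0: f64, k: usize, c: usize) -> f64 {
    // Forward pass: keep only checkpoints.
    let mut ckpts = vec![x0];
    let mut x = x0;
    for i in 1..=k {
        x = f(x);
        if i % c == 0 { ckpts.push(x); }
    }
    // Backward pass: recompute each segment from its checkpoint, then
    // accumulate the chain-rule product in reverse order.
    let mut grad = 1.0;
    for seg in (0..ckpts.len()).rev() {
        let start = seg * c;
        let len = (k - start).min(c);
        let mut states = vec![ckpts[seg]];
        for _ in 0..len { states.push(f(*states.last().unwrap())); }
        for s in states[..len].iter().rev() { grad *= df(*s); }
    }
    grad
}

fn main() {
    // Reference: store all K states (O(K) memory) and take the product.
    let (x0, k) = (0.9, 20);
    let mut states = vec![x0];
    for _ in 0..k { states.push(f(*states.last().unwrap())); }
    let full: f64 = states[..k].iter().map(|s| df(*s)).product();

    let ck = grad_checkpointed(x0, k, 4);
    assert!((full - ck).abs() < 1e-12);
}
```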
