| 1 | +# Design: Mixed Scalar Types in map/broadcast/ops (Issue #5) |
| 2 | + |
| 3 | +## Goal |
| 4 | + |
| 5 | +Generalize `strided-kernel` element-wise operations to accept different scalar types |
| 6 | +per operand (e.g., `f64 + Complex64 → Complex64`), without bulk pre-promotion. |
| 7 | +Each array is read in its native type; the closure or trait bound handles per-element |
| 8 | +conversion. |
| 9 | + |
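To make the goal concrete, here is a minimal sketch of per-element promotion without bulk pre-promotion. It uses plain slices in place of `StridedView`/`StridedViewMut`, and the names `Cplx` and `zip_map2_slices` are hypothetical stand-ins invented for this illustration (the crate itself presumably works with a type such as `Complex64`). The real-valued input is never copied into a complex buffer; each element is promoted inside the closure.

```rust
use std::ops::Add;

/// Hypothetical stand-in for a complex scalar such as `Complex64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cplx {
    pub re: f64,
    pub im: f64,
}

/// Real + complex promotes to complex, one element at a time.
impl Add<Cplx> for f64 {
    type Output = Cplx;
    fn add(self, rhs: Cplx) -> Cplx {
        Cplx { re: self + rhs.re, im: rhs.im }
    }
}

/// Slice-based stand-in for `zip_map2_into`: every operand keeps its own
/// element type, and the closure decides how (A, B) becomes D.
pub fn zip_map2_slices<D, A, B>(dest: &mut [D], a: &[A], b: &[B], f: impl Fn(A, B) -> D)
where
    D: Copy,
    A: Copy,
    B: Copy,
{
    assert_eq!(dest.len(), a.len());
    assert_eq!(dest.len(), b.len());
    for ((d, &x), &y) in dest.iter_mut().zip(a).zip(b) {
        // Each input is read in its native type; no temporary promoted array.
        *d = f(x, y);
    }
}
```

A call such as `zip_map2_slices(&mut out, &reals, &cplx, |x, z| x + z)` then reads `reals` as `f64` and `cplx` as `Cplx` directly.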
| 10 | +## Approach |
| 11 | + |
| 12 | +Replace the existing functions in place: each single-type signature becomes a multi-type signature. |
| 13 | +Existing callers stay source-compatible through Rust type inference (`D=A=B=T` is inferred for them). |
| 14 | +SIMD fast paths remain same-type only; mixed-type calls fall back to the scalar path. |
| 15 | + |
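The backward-compatibility claim can be checked with a small sketch. It assumes a slice-based stand-in, `map_slices`, for the generalized `map_into`; the helper names `legacy_double` and `widen_halves` are hypothetical. A call site written against the old single-type API carries no type annotations, so the compiler infers `D = A` by itself.

```rust
/// Generalized signature: source type `A` and destination type `D` may differ.
pub fn map_slices<D, A>(dest: &mut [D], src: &[A], f: impl Fn(A) -> D)
where
    D: Copy,
    A: Copy,
{
    assert_eq!(dest.len(), src.len());
    for (d, &x) in dest.iter_mut().zip(src) {
        *d = f(x);
    }
}

/// Call site in the old single-type style: nothing is annotated,
/// and inference picks D = A = f64.
pub fn legacy_double(dest: &mut [f64], src: &[f64]) {
    map_slices(dest, src, |x| x * 2.0);
}

/// New mixed-type call site: i32 input, f64 output.
pub fn widen_halves(dest: &mut [f64], src: &[i32]) {
    map_slices(dest, src, |x| f64::from(x) * 0.5);
}
```

One caveat follows from how Rust generics work: a caller that spells out type arguments explicitly (turbofish) would need updating, because the number of type parameters changes. Callers that rely on inference are unaffected.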
| 16 | +## Changes by module |
| 17 | + |
| 18 | +### map_view.rs — Separate type parameters per operand |
| 19 | + |
| 20 | +```rust |
| 21 | +// map_into: A → D |
| 22 | +pub fn map_into<D, A, Op>( |
| 23 | + dest: &mut StridedViewMut<D>, |
| 24 | + src: &StridedView<A, Op>, |
| 25 | + f: impl Fn(A) -> D + MaybeSync, |
| 26 | +) -> Result<()> |
| 27 | +where |
| 28 | + D: Copy + MaybeSendSync, |
| 29 | + A: Copy + ElementOpApply + MaybeSendSync, |
| 30 | + Op: ElementOp, |
| 31 | + |
| 32 | +// zip_map2_into: (A, B) → D |
| 33 | +pub fn zip_map2_into<D, A, B, OpA, OpB>( |
| 34 | + dest: &mut StridedViewMut<D>, |
| 35 | + a: &StridedView<A, OpA>, |
| 36 | + b: &StridedView<B, OpB>, |
| 37 | + f: impl Fn(A, B) -> D + MaybeSync, |
| 38 | +) -> Result<()> |
| 39 | +where |
| 40 | + D: Copy + MaybeSendSync, |
| 41 | + A: Copy + ElementOpApply + MaybeSendSync, |
| 42 | + B: Copy + ElementOpApply + MaybeSendSync, |
| 43 | + |
| 44 | +// zip_map3_into: (A, B, C) → D |
| 45 | +pub fn zip_map3_into<D, A, B, C, OpA, OpB, OpC>(...) |
| 46 | + |
| 47 | +// zip_map4_into: (A, B, C, E) → D |
| 48 | +pub fn zip_map4_into<D, A, B, C, E, OpA, OpB, OpC, OpE>(...) |
| 49 | +``` |
| 50 | + |
| 51 | +The inner-loop functions gain matching type parameters, and `elem_size` becomes the largest operand size (shown here for the two-input case): |
| 52 | +```rust |
| 53 | +let elem_size = size_of::<D>().max(size_of::<A>()).max(size_of::<B>()); |
| 54 | +``` |
| 55 | + |
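As a quick sanity check of the `elem_size` rule, the sketch below wraps the expression in two hypothetical helpers, `elem_size1` and `elem_size2` (these names are not part of the crate). The presumed intent is that block sizing is driven by the widest element among destination and sources.

```rust
use std::mem::size_of;

/// Largest element size for a one-input map (A -> D).
pub fn elem_size1<D, A>() -> usize {
    size_of::<D>().max(size_of::<A>())
}

/// Largest element size for a two-input map ((A, B) -> D).
pub fn elem_size2<D, A, B>() -> usize {
    size_of::<D>().max(size_of::<A>()).max(size_of::<B>())
}
```

With a 16-byte complex stand-in such as `[f64; 2]`, a mixed `f64` and complex call yields 16, while an all-`f64` call still yields 8, so same-type callers see the same value as before.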
| 56 | +### ops_view.rs — Trait bounds express type promotion |
| 57 | + |
| 58 | +```rust |
| 59 | +// add: dest[i] += src[i] |
| 60 | +pub fn add<D, S, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>) |
| 61 | +where D: Copy + Add<S, Output = D> + MaybeSendSync, |
| 62 | + S: Copy + ElementOpApply + MaybeSendSync, |
| 63 | + |
| 64 | +// mul: dest[i] *= src[i] |
| 65 | +pub fn mul<D, S, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>) |
| 66 | +where D: Copy + Mul<S, Output = D> + MaybeSendSync, |
| 67 | + S: Copy + ElementOpApply + MaybeSendSync, |
| 68 | + |
| 69 | +// axpy: dest[i] += alpha * src[i] |
| 70 | +pub fn axpy<D, S, A, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>, alpha: A) |
| 71 | +where A: Copy + Mul<S, Output = D>, |
| 72 | + D: Copy + Add<D, Output = D> + MaybeSendSync, |
| 73 | + S: Copy + ElementOpApply + MaybeSendSync, |
| 74 | + |
| 75 | +// fma: dest[i] += a[i] * b[i] |
| 76 | +pub fn fma<D, A, B, OpA, OpB>(dest: &mut StridedViewMut<D>, a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) |
| 77 | +where A: Copy + ElementOpApply + Mul<B, Output = D> + MaybeSendSync, |
| 78 | + B: Copy + ElementOpApply + MaybeSendSync, |
| 79 | + D: Copy + Add<D, Output = D> + MaybeSendSync, |
| 80 | + |
| 81 | +// dot: sum(a[i] * b[i]) |
| 82 | +pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R> |
| 83 | +where A: Copy + ElementOpApply + Mul<B, Output = R> + MaybeSendSync, |
| 84 | + B: Copy + ElementOpApply + MaybeSendSync, |
| 85 | + R: Copy + Zero + Add<Output = R> + MaybeSendSync, |
| 86 | + |
| 87 | +// copy_scale: dest[i] = scale * src[i] |
| 88 | +pub fn copy_scale<D, S, A, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>, scale: A) |
| 89 | +where A: Copy + Mul<S, Output = D>, |
| 90 | + D: Copy + MaybeSendSync, |
| 91 | + S: Copy + ElementOpApply + MaybeSendSync, |
| 92 | +``` |
| 93 | + |
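The following sketch shows how bounds of this shape encode promotion, using slices instead of strided views. `Cplx`, `add_slices`, and `dot_slices` are hypothetical names for this illustration, and `dot_slices` takes an explicit `zero` argument only because the sketch avoids depending on a `Zero` trait.

```rust
use std::ops::{Add, Mul};

/// Hypothetical stand-in for a complex scalar such as `Complex64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cplx {
    pub re: f64,
    pub im: f64,
}

impl Add for Cplx {
    type Output = Cplx;
    fn add(self, r: Cplx) -> Cplx {
        Cplx { re: self.re + r.re, im: self.im + r.im }
    }
}

impl Add<f64> for Cplx {
    type Output = Cplx;
    fn add(self, r: f64) -> Cplx {
        Cplx { re: self.re + r, im: self.im }
    }
}

impl Mul<f64> for Cplx {
    type Output = Cplx;
    fn mul(self, r: f64) -> Cplx {
        Cplx { re: self.re * r, im: self.im * r }
    }
}

/// add: dest[i] += src[i]. The bound `Add<S, Output = D>` only admits
/// source types whose sum lands back in the destination type.
pub fn add_slices<D, S>(dest: &mut [D], src: &[S])
where
    D: Copy + Add<S, Output = D>,
    S: Copy,
{
    assert_eq!(dest.len(), src.len());
    for (d, &s) in dest.iter_mut().zip(src) {
        *d = *d + s;
    }
}

/// dot: sum(a[i] * b[i]). The result type R is fixed by `Mul<B, Output = R>`.
pub fn dot_slices<A, B, R>(a: &[A], b: &[B], zero: R) -> R
where
    A: Copy + Mul<B, Output = R>,
    B: Copy,
    R: Copy + Add<Output = R>,
{
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(zero, |acc, (&x, &y)| acc + x * y)
}
```

In this sketch, accumulating `f64` values into a `Cplx` destination compiles, while the reverse (a `Cplx` source added into an `f64` destination) is rejected at compile time because no `Add<Cplx, Output = f64>` impl exists for `f64`. That is the sense in which the trait bounds, rather than runtime checks, express the allowed promotions.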
| 94 | +### block.rs / kernel.rs — No changes |
| 95 | + |
| 96 | +`build_plan_fused` and `compute_block_sizes` already accept `elem_size: usize`, so neither function changes. |
| 97 | +Only their call sites change: each passes the `.max()` of all operand sizes. |
| 98 | + |
| 99 | +### Unchanged modules |
| 100 | + |
| 101 | +- `reduce_view.rs`: Already supports T→U with `.max()` elem_size |
| 102 | +- `broadcast.rs`: CaptureArgs works with any closure type |
| 103 | +- `kernel.rs`: Iteration engine is type-agnostic (operates on `isize` offsets) |
| 104 | +- `copy_into`: Same-type by definition |
| 105 | +- `symmetrize_into`, `copy_transpose_scale_into`: Keep single-type |
| 106 | + |
| 107 | +### SIMD |
| 108 | + |
| 109 | +The `sum()` and `dot()` SIMD fast paths (`MaybeSimdOps`) remain limited to same-type operands with the Identity element op. |
| 110 | +Mixed-type calls bypass SIMD and take the scalar closure path through the generalized `reduce`/`zip_map2_into`. |
| 111 | + |
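The dispatch rule above can be illustrated with a small sketch. How `MaybeSimdOps` actually selects its fast path is not shown in this document, so the `TypeId` comparison, the `DotPath` enum, and `choose_dot_path` below are assumptions made purely for illustration.

```rust
use std::any::TypeId;

/// Which code path a `dot`-like call would take (hypothetical enum).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DotPath {
    /// Same element type and identity ops: the SIMD fast path may be tried.
    SimdCandidate,
    /// Anything else: scalar closure path.
    ScalarFallback,
}

/// Decide the path from the operand types and whether both views use the
/// identity element op. Only the decision is modeled, not the kernels.
pub fn choose_dot_path<A: 'static, B: 'static>(both_identity: bool) -> DotPath {
    if both_identity && TypeId::of::<A>() == TypeId::of::<B>() {
        DotPath::SimdCandidate
    } else {
        DotPath::ScalarFallback
    }
}
```

Note that `TypeId` needs `'static` bounds, which the signatures in this document do not carry; a trait-based hook would avoid that requirement, so treat this only as a model of the rule rather than a proposed mechanism.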
| 112 | +## Implementation order |
| 113 | + |
| 114 | +1. **map_view.rs**: Generalize `map_into`, `zip_map2_into`, `zip_map3_into`, `zip_map4_into` |
| 115 | +2. **ops_view.rs**: Update `add`, `mul`, `axpy`, `fma`, `dot`, `copy_scale` |
| 116 | +3. **Tests**: Add mixed f64/Complex64 test cases for each changed function |
| 117 | +4. **Verify**: All existing tests pass (backward compat) |
| 118 | + |
| 119 | +## Files to modify |
| 120 | + |
| 121 | +| File | Change | |
| 122 | +|------|--------| |
| 123 | +| `strided-kernel/src/map_view.rs` | Generalize type params, inner loops, elem_size | |
| 124 | +| `strided-kernel/src/ops_view.rs` | Generalize type params, delegate to map functions | |
| 125 | +| `strided-kernel/tests/correctness_view.rs` | Add mixed-type tests | |
| 126 | + |
| 127 | +## Out of scope |
| 128 | + |
| 129 | +- `strided-opteinsum` / `strided-einsum2`: Downstream consumers, adopt later |
| 130 | +- Mixed-type SIMD kernels: Scalar fallback is sufficient |
| 131 | +- `broadcast.rs`: Already works via closure genericity |