Commit e8f6fbe

Merge pull request #84 from tensor4all/feature/mixed-scalar-types
feat: mixed scalar types in strided-kernel ops
2 parents 93e8647 + 3f012a9 commit e8f6fbe

4 files changed

Lines changed: 571 additions & 153 deletions

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Design: Mixed Scalar Types in map/broadcast/ops (Issue #5)

## Goal

Generalize `strided-kernel` element-wise operations to accept different scalar types
per operand (e.g., `f64 + Complex64 → Complex64`), without bulk pre-promotion.
Each array is read in its native type; the closure or trait bound handles per-element
conversion.

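To make the "closure handles per-element conversion" idea concrete, here is a minimal sketch on plain slices. It is not the crate's code: `C64` is a hypothetical stand-in for `Complex64`, and `zip_map2_slices` is a hypothetical slice analogue of `zip_map2_into` without strides, element ops, or parallelism.

```rust
/// Hypothetical minimal complex type standing in for `Complex64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct C64 {
    pub re: f64,
    pub im: f64,
}

/// Hypothetical slice analogue of `zip_map2_into`: each input keeps its own
/// element type, and the closure decides how a pair becomes a `D`.
pub fn zip_map2_slices<D, A, B>(dest: &mut [D], a: &[A], b: &[B], f: impl Fn(A, B) -> D)
where
    D: Copy,
    A: Copy,
    B: Copy,
{
    assert!(
        dest.len() == a.len() && a.len() == b.len(),
        "length mismatch"
    );
    for ((d, &x), &y) in dest.iter_mut().zip(a).zip(b) {
        *d = f(x, y);
    }
}

/// `f64 + C64 -> C64` without first building a temporary complex copy of `reals`.
pub fn real_plus_complex(reals: &[f64], cplx: &[C64]) -> Vec<C64> {
    let mut out = vec![C64 { re: 0.0, im: 0.0 }; reals.len()];
    // The real operand is promoted one element at a time inside the closure.
    zip_map2_slices(&mut out, reals, cplx, |r: f64, z: C64| C64 {
        re: r + z.re,
        im: z.im,
    });
    out
}
```

The real array is only ever read as `f64`, so no intermediate complex buffer is allocated.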
## Approach

Replace in-place: change single-type signatures to multi-type signatures.
Backward compatible via Rust type inference (`D=A=B=T` inferred for existing callers).
SIMD fast paths remain same-type only; mixed-type falls back to scalar.

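The backward-compatibility claim rests on type inference, which the following sketch illustrates. `map_into_slice` is a hypothetical slice-based analogue of `map_into`, not the crate's function; the two caller functions exist only for illustration.

```rust
/// Hypothetical slice analogue of `map_into` with separate source (`A`) and
/// destination (`D`) element types.
pub fn map_into_slice<D, A>(dest: &mut [D], src: &[A], f: impl Fn(A) -> D) -> Result<(), String>
where
    D: Copy,
    A: Copy,
{
    if dest.len() != src.len() {
        return Err(format!("length mismatch: {} vs {}", dest.len(), src.len()));
    }
    for (d, &a) in dest.iter_mut().zip(src) {
        *d = f(a);
    }
    Ok(())
}

/// A caller written against a single-type signature compiles unchanged:
/// `D = A = f64` is inferred from the arguments, no turbofish required.
pub fn same_type_caller() -> Vec<f64> {
    let src = [1.0_f64, 2.0, 3.0];
    let mut dest = vec![0.0_f64; 3];
    map_into_slice(&mut dest, &src, |x| x * 2.0).unwrap();
    dest
}

/// A mixed-type caller: `A = f32`, `D = f64`, widened per element in the closure.
pub fn mixed_type_caller() -> Vec<f64> {
    let src = [1.0_f32, 2.0, 3.0];
    let mut dest = vec![0.0_f64; 3];
    map_into_slice(&mut dest, &src, |x| f64::from(x) + 0.5).unwrap();
    dest
}
```

Callers that spelled out type parameters explicitly with a turbofish would still need updating; the sketch only covers call sites that rely on inference.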
## Changes by module

### map_view.rs — Separate type parameters per operand

```rust
// map_into: A → D
pub fn map_into<D, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<A, Op>,
    f: impl Fn(A) -> D + MaybeSync,
) -> Result<()>
where
    D: Copy + MaybeSendSync,
    A: Copy + ElementOpApply + MaybeSendSync,
    Op: ElementOp,

// zip_map2_into: (A, B) → D
pub fn zip_map2_into<D, A, B, OpA, OpB>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    f: impl Fn(A, B) -> D + MaybeSync,
) -> Result<()>
where
    D: Copy + MaybeSendSync,
    A: Copy + ElementOpApply + MaybeSendSync,
    B: Copy + ElementOpApply + MaybeSendSync,

// zip_map3_into: (A, B, C) → D
pub fn zip_map3_into<D, A, B, C, OpA, OpB, OpC>(...)

// zip_map4_into: (A, B, C, E) → D
pub fn zip_map4_into<D, A, B, C, E, OpA, OpB, OpC, OpE>(...)
```

Inner loop functions gain matching type parameters. `elem_size` changes to:
```rust
let elem_size = size_of::<D>().max(size_of::<A>()).max(size_of::<B>());
```

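As a small worked example of that expression, the sketch below wraps it in a hypothetical helper, `elem_size2`, and evaluates it for a few type combinations. The reading that the largest operand should drive block sizing is an assumption about intent, not something the design states.

```rust
use std::mem::size_of;

/// Hypothetical helper mirroring the `elem_size` expression for a
/// two-input operation: the largest of the destination and source sizes.
pub fn elem_size2<D, A, B>() -> usize {
    size_of::<D>().max(size_of::<A>()).max(size_of::<B>())
}

/// `[f64; 2]` stands in for a 16-byte complex element.
pub type Complex64Like = [f64; 2];
```

With these definitions, `elem_size2::<f64, f64, f64>()` is 8 (the same-type case is unchanged), while `elem_size2::<Complex64Like, f64, Complex64Like>()` is 16, so a mixed real/complex call is sized for its widest operand.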
### ops_view.rs — Trait bounds express type promotion

```rust
// add: dest[i] += src[i]
pub fn add<D, S, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>)
where D: Copy + Add<S, Output = D> + MaybeSendSync,
      S: Copy + ElementOpApply + MaybeSendSync,

// mul: dest[i] *= src[i]
pub fn mul<D, S, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>)
where D: Copy + Mul<S, Output = D> + MaybeSendSync,
      S: Copy + ElementOpApply + MaybeSendSync,

// axpy: dest[i] += alpha * src[i]
pub fn axpy<D, S, A, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>, alpha: A)
where A: Copy + Mul<S, Output = D>,
      D: Copy + Add<D, Output = D> + MaybeSendSync,
      S: Copy + ElementOpApply + MaybeSendSync,

// fma: dest[i] += a[i] * b[i]
pub fn fma<D, A, B, OpA, OpB>(dest: &mut StridedViewMut<D>, a: &StridedView<A, OpA>, b: &StridedView<B, OpB>)
where A: Copy + ElementOpApply + Mul<B, Output = D> + MaybeSendSync,
      B: Copy + ElementOpApply + MaybeSendSync,
      D: Copy + Add<D, Output = D> + MaybeSendSync,

// dot: sum(a[i] * b[i])
pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
where A: Copy + ElementOpApply + Mul<B, Output = R> + MaybeSendSync,
      B: Copy + ElementOpApply + MaybeSendSync,
      R: Copy + Zero + Add<Output = R> + MaybeSendSync,

// copy_scale: dest[i] = scale * src[i]
pub fn copy_scale<D, S, A, Op>(dest: &mut StridedViewMut<D>, src: &StridedView<S, Op>, scale: A)
where A: Copy + Mul<S, Output = D>,
      D: Copy + MaybeSendSync,
      S: Copy + ElementOpApply + MaybeSendSync,
```

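The sketch below shows how bounds of this shape let the operator impls, rather than a closure, carry the promotion. It uses plain slices and a hypothetical `C64` type in place of `Complex64`; `add_slices` and `axpy_slices` are illustrative names that copy only the arithmetic bounds from the `add` and `axpy` signatures above.

```rust
use std::ops::{Add, Mul};

/// Hypothetical minimal complex type standing in for `Complex64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct C64 {
    pub re: f64,
    pub im: f64,
}

impl Add for C64 {
    type Output = C64;
    fn add(self, rhs: C64) -> C64 {
        C64 { re: self.re + rhs.re, im: self.im + rhs.im }
    }
}

/// Mixed-type addition: this impl is what satisfies `D: Add<S, Output = D>`.
impl Add<f64> for C64 {
    type Output = C64;
    fn add(self, rhs: f64) -> C64 {
        C64 { re: self.re + rhs, im: self.im }
    }
}

/// Mixed-type scaling: this impl is what satisfies `A: Mul<S, Output = D>`.
impl Mul<f64> for C64 {
    type Output = C64;
    fn mul(self, rhs: f64) -> C64 {
        C64 { re: self.re * rhs, im: self.im * rhs }
    }
}

/// dest[i] += src[i], with the promotion expressed by the `Add<S>` bound.
pub fn add_slices<D, S>(dest: &mut [D], src: &[S])
where
    D: Copy + Add<S, Output = D>,
    S: Copy,
{
    for (d, &s) in dest.iter_mut().zip(src) {
        *d = *d + s;
    }
}

/// dest[i] += alpha * src[i], where `alpha * src[i]` already has type `D`.
pub fn axpy_slices<D, S, A>(dest: &mut [D], src: &[S], alpha: A)
where
    A: Copy + Mul<S, Output = D>,
    D: Copy + Add<D, Output = D>,
    S: Copy,
{
    for (d, &s) in dest.iter_mut().zip(src) {
        *d = *d + alpha * s;
    }
}
```

Because `f64: Add<f64, Output = f64>` also satisfies the bounds, the same functions accept same-type arguments; a combination with no matching operator impl is rejected at compile time.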
### block.rs / kernel.rs — No changes

`build_plan_fused` and `compute_block_sizes` already accept `elem_size: usize`.
Only call sites change: use `.max()` across all operand sizes.

### Unchanged modules

- `reduce_view.rs`: Already supports `T` ≠ `U`, with `.max()` for `elem_size`
- `broadcast.rs`: `CaptureArgs` works with any closure type
- `kernel.rs`: Iteration engine is type-agnostic (operates on `isize` offsets)
- `copy_into`: Same-type by definition
- `symmetrize_into`, `copy_transpose_scale_into`: Keep single-type

### SIMD

`sum()` and `dot()` SIMD fast paths (`MaybeSimdOps`) remain same-type + Identity only.
Mixed-type calls bypass SIMD and use the scalar closure path via the generalized `reduce`/`zip_map2_into`.

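The selection rule can be summarized in a few lines. This is only a sketch of the rule as stated, not of how the crate dispatches: `DotPath` and `select_dot_path` are hypothetical names, and comparing `TypeId`s is an assumed way to test "same type" that the design does not specify.

```rust
use std::any::TypeId;

/// Hypothetical label for which implementation a `dot`-like call would take.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DotPath {
    SimdFastPath,
    ScalarFallback,
}

/// Sketch of the rule: the fast path applies only when both operands share
/// one element type and both views use the identity element op.
pub fn select_dot_path<A: 'static, B: 'static>(a_is_identity: bool, b_is_identity: bool) -> DotPath {
    let same_type = TypeId::of::<A>() == TypeId::of::<B>();
    if same_type && a_is_identity && b_is_identity {
        DotPath::SimdFastPath
    } else {
        DotPath::ScalarFallback
    }
}
```

Under this rule an `f64`/`f64` call with identity ops keeps its fast path, while an `f64`/`f32` call, or any call with a non-identity op, takes the scalar path.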
## Implementation order

1. **map_view.rs**: Generalize `map_into`, `zip_map2_into`, `zip_map3_into`, `zip_map4_into`
2. **ops_view.rs**: Update `add`, `mul`, `axpy`, `fma`, `dot`, `copy_scale`
3. **Tests**: Add mixed f64/Complex64 test cases for each changed function
4. **Verify**: All existing tests pass (backward compat)

## Files to modify

| File | Change |
|------|--------|
| `strided-kernel/src/map_view.rs` | Generalize type params, inner loops, elem_size |
| `strided-kernel/src/ops_view.rs` | Generalize type params, delegate to map functions |
| `strided-kernel/tests/correctness_view.rs` | Add mixed-type tests |

## Out of scope

- `strided-opteinsum` / `strided-einsum2`: Downstream consumers, adopt later
- Mixed-type SIMD kernels: Scalar fallback is sufficient
- `broadcast.rs`: Already works via closure genericity
