Skip to content

Commit 6155d5f

Browse files
Merge pull request #88 from tensor4all/feat/canonical-binary-output-ids
Canonicalize binary intermediate id order in opteinsum
2 parents 7aa260c + 341705a commit 6155d5f

1 file changed

Lines changed: 42 additions & 0 deletions

File tree

strided-opteinsum/src/expr.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
127127
/// needs from this node) AND it is actually present in at least one child
128128
/// subtree.
129129
fn compute_contract_output_ids(args: &[EinsumNode], needed_ids: &[char]) -> Vec<char> {
130+
if args.len() == 2 {
131+
let left_ids = collect_all_ids(&args[0]);
132+
let right_ids = collect_all_ids(&args[1]);
133+
return compute_binary_output_ids(&left_ids, &right_ids, needed_ids);
134+
}
135+
130136
// Walk args in order and collect ids preserving first-seen order
131137
let mut all_ids_ordered = Vec::new();
132138
for arg in args {
@@ -223,6 +229,36 @@ fn out_dims_from_ids(
223229
Ok(out_dims)
224230
}
225231

232+
/// Compute binary contraction output id order.
233+
///
234+
/// Uses canonical `[lo, ro, batch]` order:
235+
/// - lo: ids only in left and needed
236+
/// - ro: ids only in right and needed
237+
/// - batch: ids in both and needed
238+
fn compute_binary_output_ids(
239+
left_ids: &[char],
240+
right_ids: &[char],
241+
needed_ids: &[char],
242+
) -> Vec<char> {
243+
let mut out = Vec::new();
244+
for &id in left_ids {
245+
if needed_ids.contains(&id) && !right_ids.contains(&id) && !out.contains(&id) {
246+
out.push(id);
247+
}
248+
}
249+
for &id in right_ids {
250+
if needed_ids.contains(&id) && !left_ids.contains(&id) && !out.contains(&id) {
251+
out.push(id);
252+
}
253+
}
254+
for &id in left_ids {
255+
if needed_ids.contains(&id) && right_ids.contains(&id) && !out.contains(&id) {
256+
out.push(id);
257+
}
258+
}
259+
out
260+
}
261+
226262
/// Generic inner function for pairwise contraction with buffer pool.
227263
///
228264
/// Acquires an output buffer, runs `einsum2_into`, and releases input buffers
@@ -1005,6 +1041,12 @@ mod tests {
10051041
.into()
10061042
}
10071043

1044+
#[test]
1045+
fn test_binary_output_ids_canonical_lo_ro_batch_order() {
1046+
let out = compute_binary_output_ids(&['b', 'a', 'x'], &['x', 'c', 'a'], &['b', 'c', 'a']);
1047+
assert_eq!(out, vec!['b', 'c', 'a']);
1048+
}
1049+
10081050
#[test]
10091051
fn test_matmul() {
10101052
let code = parse_einsum("ij,jk->ik").unwrap();

0 commit comments

Comments
 (0)