@@ -127,6 +127,12 @@ fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
127127/// needs from this node) AND it is actually present in at least one child
128128/// subtree.
129129fn compute_contract_output_ids ( args : & [ EinsumNode ] , needed_ids : & [ char ] ) -> Vec < char > {
130+ if args. len ( ) == 2 {
131+ let left_ids = collect_all_ids ( & args[ 0 ] ) ;
132+ let right_ids = collect_all_ids ( & args[ 1 ] ) ;
133+ return compute_binary_output_ids ( & left_ids, & right_ids, needed_ids) ;
134+ }
135+
130136 // Walk args in order and collect ids preserving first-seen order
131137 let mut all_ids_ordered = Vec :: new ( ) ;
132138 for arg in args {
@@ -223,6 +229,36 @@ fn out_dims_from_ids(
223229 Ok ( out_dims)
224230}
225231
232+ /// Compute binary contraction output id order.
233+ ///
234+ /// Uses canonical `[lo, ro, batch]` order:
235+ /// - lo: ids only in left and needed
236+ /// - ro: ids only in right and needed
237+ /// - batch: ids in both and needed
238+ fn compute_binary_output_ids (
239+ left_ids : & [ char ] ,
240+ right_ids : & [ char ] ,
241+ needed_ids : & [ char ] ,
242+ ) -> Vec < char > {
243+ let mut out = Vec :: new ( ) ;
244+ for & id in left_ids {
245+ if needed_ids. contains ( & id) && !right_ids. contains ( & id) && !out. contains ( & id) {
246+ out. push ( id) ;
247+ }
248+ }
249+ for & id in right_ids {
250+ if needed_ids. contains ( & id) && !left_ids. contains ( & id) && !out. contains ( & id) {
251+ out. push ( id) ;
252+ }
253+ }
254+ for & id in left_ids {
255+ if needed_ids. contains ( & id) && right_ids. contains ( & id) && !out. contains ( & id) {
256+ out. push ( id) ;
257+ }
258+ }
259+ out
260+ }
261+
226262/// Generic inner function for pairwise contraction with buffer pool.
227263///
228264/// Acquires an output buffer, runs `einsum2_into`, and releases input buffers
@@ -1005,6 +1041,12 @@ mod tests {
10051041 . into ( )
10061042 }
10071043
1044+ #[ test]
1045+ fn test_binary_output_ids_canonical_lo_ro_batch_order ( ) {
1046+ let out = compute_binary_output_ids ( & [ 'b' , 'a' , 'x' ] , & [ 'x' , 'c' , 'a' ] , & [ 'b' , 'c' , 'a' ] ) ;
1047+ assert_eq ! ( out, vec![ 'b' , 'c' , 'a' ] ) ;
1048+ }
1049+
10081050 #[ test]
10091051 fn test_matmul ( ) {
10101052 let code = parse_einsum ( "ij,jk->ik" ) . unwrap ( ) ;
0 commit comments