Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tenferro-einsum/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use computegraph::fragment::FragmentBuilder;
use computegraph::types::{OpMode, ValRef};
use computegraph::GraphOp;

use tenferro_ops::dim_expr::DimExpr;
use tenferro_ops::semiring_ops::SemiringOps;
use tenferro_tensor::DotGeneralConfig;

Expand Down Expand Up @@ -88,7 +89,7 @@ fn reduce_val<Op: GraphOp + SemiringOps>(
.map(|(_, &s)| s)
.collect();
let outputs = builder.add_op(
Op::reduce_sum(reduce_axes, lv.shape.clone()),
Op::reduce_sum(reduce_axes, DimExpr::from_concrete(&lv.shape)),
vec![lv.val.clone()],
OpMode::Primal,
);
Expand Down Expand Up @@ -397,12 +398,12 @@ fn outer_product<Op: GraphOp + SemiringOps>(
.collect();

let lhs_bc = builder.add_op(
Op::broadcast_in_dim(combined_shape.clone(), lhs_dims),
Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), lhs_dims),
vec![lhs.val.clone()],
OpMode::Primal,
);
let rhs_bc = builder.add_op(
Op::broadcast_in_dim(combined_shape.clone(), rhs_dims),
Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), rhs_dims),
vec![rhs.val.clone()],
OpMode::Primal,
);
Expand Down
96 changes: 66 additions & 30 deletions tenferro-ops/src/ad/contraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use computegraph::fragment::FragmentBuilder;
use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
use tenferro_tensor::{CompareDir, DotGeneralConfig};

use crate::dim_expr::DimExpr;
use crate::std_tensor_op::StdTensorOp;

pub fn linearize_dot_general(
Expand Down Expand Up @@ -78,16 +79,17 @@ pub fn linearize_reduce_prod(
primal_out: &[GlobalValKey<StdTensorOp>],
tangent_in: &[Option<LocalValId>],
axes: &[usize],
input_shape: &[usize],
input_shape: &[DimExpr],
) -> Vec<Option<LocalValId>> {
let Some(dx) = tangent_in[0] else {
return vec![None];
};

let kept_dims = kept_dims(input_shape.len(), axes);
let prod_broadcast = broadcast_reduction_output_fixed(
let prod_broadcast = broadcast_reduction_output(
builder,
ValRef::External(primal_out[0].clone()),
ValRef::External(primal_in[0].clone()),
input_shape,
&kept_dims,
);
Expand Down Expand Up @@ -125,16 +127,17 @@ pub fn linearize_reduce_chooser(
primal_out: &[GlobalValKey<StdTensorOp>],
tangent_in: &[Option<LocalValId>],
axes: &[usize],
input_shape: &[usize],
input_shape: &[DimExpr],
) -> Vec<Option<LocalValId>> {
let Some(dx) = tangent_in[0] else {
return vec![None];
};

let kept_dims = kept_dims(input_shape.len(), axes);
let answer_broadcast = broadcast_reduction_output_fixed(
let answer_broadcast = broadcast_reduction_output(
builder,
ValRef::External(primal_out[0].clone()),
ValRef::External(primal_in[0].clone()),
input_shape,
&kept_dims,
);
Expand Down Expand Up @@ -238,6 +241,7 @@ pub fn transpose_reduce_sum(
builder: &mut FragmentBuilder<StdTensorOp>,
cotangent_out: &[Option<LocalValId>],
op: &StdTensorOp,
inputs: &[ValRef<StdTensorOp>],
) -> Vec<Option<LocalValId>> {
let StdTensorOp::ReduceSum { axes, input_shape } = op else {
unreachable!("transpose_reduce_sum expects ReduceSum");
Expand All @@ -251,7 +255,7 @@ pub fn transpose_reduce_sum(
let cotangent = if kept_dims.is_empty() {
let scalar = builder.add_op(
StdTensorOp::Reshape {
from_shape: vec![1],
from_shape: DimExpr::from_concrete(&[1]),
to_shape: vec![],
},
vec![ValRef::Local(ct)],
Expand All @@ -263,15 +267,22 @@ pub fn transpose_reduce_sum(
} else {
ValRef::Local(ct)
};
let shape = DimExpr::remap_all(input_shape, 0, 1);
let needs_shape_source = DimExpr::max_input_idx_all(&shape).is_some_and(|idx| idx > 0);
let mut op_inputs = vec![cotangent];
let active_mask = if needs_shape_source {
op_inputs.push(inputs[0].clone());
vec![true, false]
} else {
vec![true]
};
let out = builder.add_op(
StdTensorOp::BroadcastInDim {
shape: input_shape.clone(),
shape,
dims: kept_dims,
},
vec![cotangent],
OpMode::Linear {
active_mask: vec![true],
},
op_inputs,
OpMode::Linear { active_mask },
);
vec![Some(out[0])]
}
Expand All @@ -293,20 +304,28 @@ pub fn transpose_reduce_prod(
Some(ct) => {
let kept_dims = kept_dims(input_shape.len(), axes);
let cotangent = normalize_reduction_cotangent(builder, ct, &kept_dims);
let shape = DimExpr::remap_all(input_shape, 0, 1);
let needs_shape_source = DimExpr::max_input_idx_all(&shape).is_some_and(|idx| idx > 0);
let mut op_inputs = vec![cotangent];
let active_mask = if needs_shape_source {
op_inputs.push(inputs[0].clone());
vec![true, false]
} else {
vec![true]
};
let cotangent = builder.add_op(
StdTensorOp::BroadcastInDim {
shape: input_shape.clone(),
shape,
dims: kept_dims.clone(),
},
vec![cotangent],
OpMode::Linear {
active_mask: vec![true],
},
op_inputs,
OpMode::Linear { active_mask },
)[0];
let prod = builder.add_op(op.clone(), vec![inputs[0].clone()], OpMode::Primal)[0];
let prod_broadcast = broadcast_reduction_output_fixed(
let prod_broadcast = broadcast_reduction_output(
builder,
ValRef::Local(prod),
inputs[0].clone(),
input_shape,
&kept_dims,
);
Expand Down Expand Up @@ -349,20 +368,28 @@ pub fn transpose_reduce_chooser(
Some(ct) => {
let kept_dims = kept_dims(input_shape.len(), axes);
let cotangent = normalize_reduction_cotangent(builder, ct, &kept_dims);
let shape = DimExpr::remap_all(input_shape, 0, 1);
let needs_shape_source = DimExpr::max_input_idx_all(&shape).is_some_and(|idx| idx > 0);
let mut op_inputs = vec![cotangent];
let active_mask = if needs_shape_source {
op_inputs.push(inputs[0].clone());
vec![true, false]
} else {
vec![true]
};
let cotangent = builder.add_op(
StdTensorOp::BroadcastInDim {
shape: input_shape.clone(),
shape,
dims: kept_dims.clone(),
},
vec![cotangent],
OpMode::Linear {
active_mask: vec![true],
},
op_inputs,
OpMode::Linear { active_mask },
)[0];
let answer = builder.add_op(op.clone(), vec![inputs[0].clone()], OpMode::Primal)[0];
let answer_broadcast = broadcast_reduction_output_fixed(
let answer_broadcast = broadcast_reduction_output(
builder,
ValRef::Local(answer),
inputs[0].clone(),
input_shape,
&kept_dims,
);
Expand All @@ -372,9 +399,10 @@ pub fn transpose_reduce_chooser(
ValRef::Local(answer_broadcast),
);
let counts = reduction_location_counts(builder, indicators, axes, input_shape);
let counts_broadcast = broadcast_reduction_output_fixed(
let counts_broadcast = broadcast_reduction_output(
builder,
ValRef::Local(counts),
inputs[0].clone(),
input_shape,
&kept_dims,
);
Expand Down Expand Up @@ -425,7 +453,7 @@ fn normalize_reduction_cotangent(
if kept_dims.is_empty() {
let scalar = builder.add_op(
StdTensorOp::Reshape {
from_shape: vec![1],
from_shape: DimExpr::from_concrete(&[1]),
to_shape: vec![],
},
vec![ValRef::Local(cotangent)],
Expand All @@ -439,18 +467,26 @@ fn normalize_reduction_cotangent(
}
}

fn broadcast_reduction_output_fixed(
fn broadcast_reduction_output(
builder: &mut FragmentBuilder<StdTensorOp>,
output: ValRef<StdTensorOp>,
input_shape: &[usize],
shape_source: ValRef<StdTensorOp>,
input_shape: &[DimExpr],
kept_dims: &[usize],
) -> LocalValId {
let shape = DimExpr::remap_all(input_shape, 0, 1);
let needs_shape_source = DimExpr::max_input_idx_all(&shape).is_some_and(|idx| idx > 0);
let inputs = if needs_shape_source {
vec![output, shape_source]
} else {
vec![output]
};
builder.add_op(
StdTensorOp::BroadcastInDim {
shape: input_shape.to_vec(),
shape,
dims: kept_dims.to_vec(),
},
vec![output],
inputs,
OpMode::Primal,
)[0]
}
Expand All @@ -471,7 +507,7 @@ fn reduction_location_counts(
builder: &mut FragmentBuilder<StdTensorOp>,
indicators: LocalValId,
axes: &[usize],
input_shape: &[usize],
input_shape: &[DimExpr],
) -> LocalValId {
builder.add_op(
StdTensorOp::ReduceSum {
Expand All @@ -491,7 +527,7 @@ fn normalize_scalar_cotangent(
if output_rank == 0 {
let scalar = builder.add_op(
StdTensorOp::Reshape {
from_shape: vec![1],
from_shape: DimExpr::from_concrete(&[1]),
to_shape: vec![],
},
vec![ValRef::Local(cotangent)],
Expand Down
Loading
Loading