16 changes: 10 additions & 6 deletions tenferro-prims/src/family_cpu_reduction.rs
@@ -165,12 +165,16 @@ pub(crate) fn execute_mean_reduction<S: CpuScalarValue>(
         return execute_unary_map(alpha, input, beta, output, |x| x);
     }
 
-    let scale = scalar_from_usize::<S>(
-        reduced_axes
-            .iter()
-            .map(|&axis| input.dims()[axis])
-            .product(),
-    )?;
+    let reduced_total: usize = reduced_axes
+        .iter()
+        .map(|&axis| input.dims()[axis])
+        .product();
+    if reduced_total == 0 {
+        return Err(Error::InvalidArgument(
+            "mean reduction requires a non-empty reduction domain".into(),
+        ));
+    }
+    let scale = scalar_from_usize::<S>(reduced_total)?;
     let mean_scale = S::one() / scale;
 
     if reduced_axes.len() == input.ndim() {
40 changes: 40 additions & 0 deletions tenferro-prims/src/tests/scalar_phase1.rs
@@ -156,3 +156,43 @@ fn cuda_scalar_phase1_does_not_advertise_unimplemented_ops() {
)
);
}

#[test]
fn cpu_scalar_mean_reduction_rejects_empty_reduction_domain() {
let mut ctx = CpuContext::new(1);

let mean_desc = ScalarPrimsDescriptor::Reduction {
modes_a: vec![0, 1],
modes_c: vec![1],
op: ScalarReductionOp::Mean,
};
let mean_plan = <CpuBackend as TensorScalarPrims<Standard<f64>>>::plan(
&mut ctx,
&mean_desc,
&[&[0, 2], &[2]],
)
.unwrap();

let input = Tensor::<f64>::zeros(
&[0, 2],
LogicalMemorySpace::MainMemory,
MemoryOrder::ColumnMajor,
);
let mut mean_out = Tensor::<f64>::zeros(
&[2],
LogicalMemorySpace::MainMemory,
MemoryOrder::ColumnMajor,
);
let result = <CpuBackend as TensorScalarPrims<Standard<f64>>>::execute(
&mut ctx,
&mean_plan,
1.0,
&[&input],
0.0,
&mut mean_out,
);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, tenferro_device::Error::InvalidArgument(_)));
assert!(err.to_string().contains("non-empty reduction domain"));
}
6 changes: 5 additions & 1 deletion tenferro-tensor/src/tensor/constructors.rs
@@ -207,7 +207,11 @@ impl<T: Scalar> Tensor<T> {
         let strides = compute_contiguous_strides(&dims, order);
         let mut data = vec![T::zero(); n * n];
         for i in 0..n {
-            let pos = (i as isize * strides[0] + i as isize * strides[1]) as usize;
+            let pos = (i as isize)
+                .checked_mul(strides[0])
+                .and_then(|a| a.checked_add((i as isize).checked_mul(strides[1])?))
+                .and_then(|pos| usize::try_from(pos).ok())
+                .expect("position overflow in eye");
             data[pos] = T::one();
         }
         Self::finish_allocation(
32 changes: 25 additions & 7 deletions tenferro-tensor/src/tensor/data_ops.rs
@@ -180,13 +180,31 @@ impl<T: Scalar> Tensor<T> {
                         continue;
                     }
 
-                    let src_pos = (self.offset
-                        + src_batch_off
-                        + i as isize * self.strides[0]
-                        + j as isize * self.strides[1]) as usize;
-                    let dst_pos = (dst_batch_off
-                        + i as isize * out_strides[0]
-                        + j as isize * out_strides[1]) as usize;
+                    let i_stride0 = (i as isize)
+                        .checked_mul(self.strides[0])
+                        .expect("stride multiplication overflow in triangular extraction");
+                    let j_stride1 = (j as isize)
+                        .checked_mul(self.strides[1])
+                        .expect("stride multiplication overflow in triangular extraction");
+                    let src_pos = self
+                        .offset
+                        .checked_add(src_batch_off)
+                        .and_then(|off| off.checked_add(i_stride0))
+                        .and_then(|off| off.checked_add(j_stride1))
+                        .and_then(|pos| usize::try_from(pos).ok())
+                        .expect("src position overflow in triangular extraction");
+
+                    let i_out_stride0 = (i as isize)
+                        .checked_mul(out_strides[0])
+                        .expect("out stride multiplication overflow in triangular extraction");
+                    let j_out_stride1 = (j as isize)
+                        .checked_mul(out_strides[1])
+                        .expect("out stride multiplication overflow in triangular extraction");
+                    let dst_pos = dst_batch_off
+                        .checked_add(i_out_stride0)
+                        .and_then(|off| off.checked_add(j_out_stride1))
+                        .and_then(|pos| usize::try_from(pos).ok())
+                        .expect("dst position overflow in triangular extraction");
                     data[dst_pos] = src[src_pos];
                 }
             }
100 changes: 100 additions & 0 deletions tenferro-tensor/src/tests/mod.rs
@@ -22,3 +22,103 @@ fn tensor_debug_is_summary_style() {
assert!(dbg.contains("logical_memory_space"));
assert!(dbg.contains("is_contiguous"));
}

#[test]
fn eye_creates_identity_matrix_col_major() {
let id = Tensor::<f64>::eye(3, LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
assert_eq!(id.dims(), &[3, 3]);

let data = id.buffer().as_slice().unwrap();
assert_eq!(data[0], 1.0);
assert_eq!(data[4], 1.0);
assert_eq!(data[8], 1.0);
assert_eq!(data[1], 0.0);
assert_eq!(data[2], 0.0);
assert_eq!(data[3], 0.0);
}

#[test]
fn eye_creates_identity_matrix_row_major() {
let id = Tensor::<f64>::eye(3, LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor);
assert_eq!(id.dims(), &[3, 3]);

let data = id.buffer().as_slice().unwrap();
assert_eq!(data[0], 1.0);
assert_eq!(data[4], 1.0);
assert_eq!(data[8], 1.0);
assert_eq!(data[1], 0.0);
assert_eq!(data[2], 0.0);
assert_eq!(data[3], 0.0);
}

#[test]
fn tril_extracts_lower_triangular() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let a = Tensor::<f64>::from_slice(&data, &[3, 3], MemoryOrder::ColumnMajor).unwrap();
let lower = a.tril(0);

let expected = [1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0];
assert_eq!(lower.buffer().as_slice().unwrap(), expected);
}

#[test]
fn triu_extracts_upper_triangular() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let a = Tensor::<f64>::from_slice(&data, &[3, 3], MemoryOrder::ColumnMajor).unwrap();
let upper = a.triu(0);

let expected = [1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0];
assert_eq!(upper.buffer().as_slice().unwrap(), expected);
}

#[test]
fn tril_with_diagonal_offset() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let a = Tensor::<f64>::from_slice(&data, &[3, 3], MemoryOrder::ColumnMajor).unwrap();
let lower = a.tril(1);

let expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 8.0, 9.0];
assert_eq!(lower.buffer().as_slice().unwrap(), expected);
}

#[test]
fn triu_with_diagonal_offset() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let a = Tensor::<f64>::from_slice(&data, &[3, 3], MemoryOrder::ColumnMajor).unwrap();
let upper = a.triu(-1);

let expected = [1.0, 2.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
assert_eq!(upper.buffer().as_slice().unwrap(), expected);
}

#[test]
fn narrow_returns_subrange() {
let t = Tensor::<f64>::zeros(
&[2, 10],
LogicalMemorySpace::MainMemory,
MemoryOrder::ColumnMajor,
);
let sub = t.narrow(1, 2, 3).unwrap();
assert_eq!(sub.dims(), &[2, 3]);
}

#[test]
fn narrow_rejects_out_of_bounds() {
let t = Tensor::<f64>::zeros(
&[2, 10],
LogicalMemorySpace::MainMemory,
MemoryOrder::ColumnMajor,
);
assert!(t.narrow(1, 8, 5).is_err());
assert!(t.narrow(1, 0, 15).is_err());
}

#[test]
fn select_returns_single_slice() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let t = Tensor::<f64>::from_slice(&data, &[2, 4], MemoryOrder::ColumnMajor).unwrap();
let slice = t.select(1, 1).unwrap();
assert_eq!(slice.dims(), &[2]);
let slice_data = slice.contiguous(MemoryOrder::ColumnMajor);
assert_eq!(slice_data.buffer().as_slice().unwrap(), &[3.0, 4.0]);
}