Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,64 @@ where
{
}

/// An iterator over the elements of the transpose of a Dimensional array.
///
/// This struct is created by the `iter_transpose` method on Dimensional
/// to give access to row major elements of the transpose. Since we want
/// transposition to be out of place, not much sense in making a mutable
/// iterator for the moment.
pub struct DimensionalTransposeIter<'a, T, S, const N: usize>
where
T: Num + Copy,
S: DimensionalStorage<T, N>,
{
dimensional: &'a Dimensional<T, S, N>,
current_index: [usize; N],
remaining: usize,
}

impl<'a, T, S, const N: usize> Iterator for DimensionalTransposeIter<'a, T, S, N>
where
T: Num + Copy,
S: DimensionalStorage<T, N>,
{
type Item = &'a T;

fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}

// Transpose the current index
let transposed_index: [usize; N] = self
.current_index
.iter()
.enumerate()
.map(|(i, _)| self.current_index[N - 1 - i])
.collect::<Vec<_>>()
.try_into()
.expect("Failed to transpose index");

let result = &self.dimensional[transposed_index];

// Update the index for the next iteration
for i in (0..N).rev() {
self.current_index[i] += 1;
if self.current_index[i] < self.dimensional.shape()[N - 1 - i] {
break;
}
self.current_index[i] = 0;
}

self.remaining -= 1;
Some(result)
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.remaining, Some(self.remaining))
}
}

/// A mutable iterator over the elements of a Dimensional array.
///
/// This struct is created by the `iter_mut` method on Dimensional. It provides
Expand Down Expand Up @@ -188,6 +246,36 @@ where
remaining: len,
}
}

/// Returns an iterator over the eleents of the transposed array
/// The iterator yields all items in the transposed array in row-major order.
///
/// # Examples
///
/// ```
/// use dimensionals::{Dimensional, vector, matrix, LinearArrayStorage};
/// let v = vector![1, 2, 3, 4, 5];
/// let mut iter = v.iter_transpose();
/// assert_eq!(iter.next(), Some(&1));
/// assert_eq!(iter.next(), Some(&2));

/// let m = matrix![[1, 2], [3, 4]];
/// let mut iter = m.iter_transpose();
/// assert_eq!(iter.next(), Some(&1));
/// assert_eq!(iter.next(), Some(&3));
/// assert_eq!(iter.next(), Some(&2));
/// assert_eq!(iter.next(), Some(&4));
/// assert_eq!(iter.next(), None);
///
/// ```
pub fn iter_transpose(&self) -> DimensionalTransposeIter<T, S, N> {
let len = self.len();
DimensionalTransposeIter {
dimensional: self,
current_index: [0; N],
remaining: len,
}
}
}

// TODO: Since these are consuming, do they really need a lifetime?
Expand Down Expand Up @@ -220,7 +308,7 @@ where

#[cfg(test)]
mod tests {
use crate::{matrix, storage::LinearArrayStorage, Dimensional};
use crate::{matrix, storage::LinearArrayStorage, vector, Dimensional};

// ... (previous tests remain unchanged)

Expand All @@ -234,4 +322,20 @@ mod tests {
assert_eq!(iter.next(), Some(&mut 4));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_transpose() {
let v = vector![1, 2, 3, 4, 5];
let mut iter = v.iter_transpose();
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));

let m = matrix![[1, 2], [3, 4]];
let mut iter = m.iter_transpose();
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&4));
assert_eq!(iter.next(), None);
}
}
114 changes: 114 additions & 0 deletions src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,86 @@ where
}
}

impl<T: Num + Copy + std::iter::Sum + std::fmt::Debug, S, const N: usize> Dimensional<T, S, N>
where
S: DimensionalStorage<T, N>,
{
pub fn block(&self, idxs: [usize; N], szs: [usize; N]) -> Dimensional<T, S, N> {
for (&i, &j) in self.shape().to_vec().iter().zip(szs.to_vec().iter()) {
assert!(
i >= j,
"Requested block sizes cannot be bigger than total number of indices {} {}",
i,
j
);
}

let mut retval: Dimensional<T, S, N> = Dimensional::zeros(szs);

let total_elements: usize = szs.iter().product();

let mut target_index = [0; N];
let mut block_index: usize = 0;

while block_index < total_elements {
let mut source_index = idxs;
for i in 0..N {
source_index[i] += target_index[i];
}

retval[target_index] = self[source_index];

block_index += 1;
for i in (0..N).rev() {
target_index[i] += 1;
if target_index[i] < szs[i] {
break;
} else {
target_index[i] = 0;
}
}
}
retval
}
}

impl<T: Num + Copy + std::iter::Sum, S, const N: usize> Dimensional<T, S, N>
where
S: DimensionalStorage<T, N>,
{
pub fn trace(&self) -> T {
(0..*self.shape().iter().min().unwrap()).fold(T::zero(), |sum, i| sum + self[[i; N]])
}
}
impl<T: Num + Copy + std::iter::Sum, S> Dimensional<T, S, 2>
where
S: DimensionalStorage<T, 2>,
{
pub fn transpose(&self) -> Dimensional<T, S, 2> {
let (rows, cols) = (self.shape()[1], self.shape()[0]);
Self::from_fn([rows, cols], |[i, j]| self[[j, i]])
}
}

///Implements matrix multiplication for 2-Dimensional arrays
impl<T: Num + Copy + std::iter::Sum, S> Dimensional<T, S, 2>
where
S: DimensionalStorage<T, 2>,
{
pub fn matmul(&self, rhs: &Self) -> Dimensional<T, S, 2> {
assert_eq!(
self.shape()[1],
rhs.shape()[0],
"Interior dimensions must match for matrix multiplication"
);
let (rows, cols) = (self.shape()[0], rhs.shape()[1]);

Self::from_fn([rows, cols], |[i, j]| {
(0..self.shape()[1]).fold(T::zero(), |sum, k| sum + self[[i, k]] * rhs[[k, j]])
})
}
}

// Assignment operations

/// Implements scalar addition assignment for Dimensional arrays.
Expand Down Expand Up @@ -493,6 +573,40 @@ mod tests {
// Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division
}

#[test]
fn test_matrix_blocking() {
let m1 = matrix![[1, 2], [3, 4]];
let m2 = matrix![[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]];

assert_eq!(m1.block([0, 0], [1, 2]), matrix![[1, 2]]);
assert_eq!(m1.block([0, 0], m1.shape()), m1);
assert_eq!(
m2.block([0, 1], [2, 4]),
matrix![[2, 3, 4, 5], [8, 9, 10, 11]]
);
assert_eq!(m2.block([1, 3], [1, 1]), matrix![[10]]);
}
#[test]
fn test_matrix_specific_operations() {
let m1 = matrix![[1, 2], [3, 4]];
let m2 = matrix![[5, 6], [7, 8]];
let m3 = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]];

assert_eq!(m1.matmul(&m2), matrix![[19, 22], [43, 50]]);

assert_eq!(m1.transpose(), matrix![[1, 3], [2, 4]]);

assert_eq!(m1.trace(), 5);
assert_eq!(m3.trace(), 15);
}

#[test]
#[should_panic(expected = "Interior dimensions must match for matrix multiplication")]
fn test_mismatched_matrix_mult() {
let m1 = matrix![[1, 2], [3, 4]];
let m3 = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]];
let _ = m1.matmul(&m3);
}
#[test]
fn test_mixed_dimensional_operations() {
let v = vector![1, 2, 3];
Expand Down