From e02c4af72cd6490df79330d32bc291a4ac9ff93d Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 11:57:41 -0700 Subject: [PATCH 1/8] added transposition iterator --- src/iterators.rs | 105 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/iterators.rs b/src/iterators.rs index e91b5f7..c5e9199 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -64,6 +64,61 @@ where { } +/// An iterator over the elements of the transpose of a Dimensional array. +/// +/// This struct is created by the `iter_transpose` method on Dimensional +/// to give access to row major elements of the transpose. Since we want +/// transposition to be out of place, not much sense in making a mutable +/// iterator for the moment. +pub struct DimensionalTransposeIter<'a, T, S, const N: usize> +where + T: Num + Copy, + S: DimensionalStorage, +{ + dimensional: &'a Dimensional, + current_index: [usize; N], + remaining: usize, +} + +impl<'a, T, S, const N: usize> Iterator for DimensionalTransposeIter<'a, T, S, N> +where + T: Num + Copy, + S: DimensionalStorage, +{ + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + // Transpose the current index + let transposed_index: [usize; N] = self.current_index.iter() + .enumerate() + .map(|(i, _)| self.current_index[N - 1 - i]) + .collect::>() + .try_into() + .expect("Failed to transpose index"); + + let result = &self.dimensional[transposed_index]; + + // Update the index for the next iteration + for i in (0..N).rev() { + self.current_index[i] += 1; + if self.current_index[i] < self.dimensional.shape()[N - 1 - i] { + break; + } + self.current_index[i] = 0; + } + + self.remaining -= 1; + Some(result) + } + + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} /// A mutable iterator over the elements of a Dimensional array. /// /// This struct is created by the `iter_mut` method on Dimensional. It provides @@ -188,6 +243,36 @@ where remaining: len, } } + + /// Returns an iterator over the eleents of the transposed array + /// The iterator yields all items in the transposed array in row-major order. + /// + /// # Examples + /// + /// ``` + /// use dimensionals::{Dimensional, vector, matrix, LinearArrayStorage}; + /// let v = vector![1, 2, 3, 4, 5]; + /// let mut iter = v.iter_transpose(); + /// assert_eq!(iter.next(), Some(&1)); + /// assert_eq!(iter.next(), Some(&2)); + + /// let m = matrix![[1, 2], [3, 4]]; + /// let mut iter = m.iter_transpose(); + /// assert_eq!(iter.next(), Some(&1)); + /// assert_eq!(iter.next(), Some(&3)); + /// assert_eq!(iter.next(), Some(&2)); + /// assert_eq!(iter.next(), Some(&4)); + /// assert_eq!(iter.next(), None); + /// + /// ``` + pub fn iter_transpose(&self) -> DimensionalTransposeIter{ + let len = self.len(); + DimensionalTransposeIter { + dimensional: self, + current_index: [0;N], + remaining: len, + } + } } // TODO: Since these are consuming, do they really need a lifetime? @@ -220,7 +305,7 @@ where #[cfg(test)] mod tests { - use crate::{matrix, storage::LinearArrayStorage, Dimensional}; + use crate::{matrix, storage::LinearArrayStorage, Dimensional, vector}; // ... (previous tests remain unchanged) @@ -234,4 +319,22 @@ mod tests { assert_eq!(iter.next(), Some(&mut 4)); assert_eq!(iter.next(), None); } + + + #[test] + fn test_iter_transpose(){ + let v = vector![1, 2, 3, 4, 5]; + let mut iter = v.iter_transpose(); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.next(), Some(&2)); + + let m = matrix![[1, 2], [3, 4]]; + let mut iter = m.iter_transpose(); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.next(), Some(&4)); + assert_eq!(iter.next(), None); + + } } From c712cf6adce25e9fd625b40c710eeb1dc47c5b2b Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 13:29:03 -0700 Subject: [PATCH 2/8] [wip] first pass naive matrix mult --- src/iterators.rs | 1 + src/operators.rs | 45 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/iterators.rs b/src/iterators.rs index c5e9199..842b2e8 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -119,6 +119,7 @@ where (self.remaining, Some(self.remaining)) } } + /// A mutable iterator over the elements of a Dimensional array. /// /// This struct is created by the `iter_mut` method on Dimensional. It provides diff --git a/src/operators.rs b/src/operators.rs index 83cc73b..683e28c 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -1,7 +1,7 @@ -use crate::{storage::DimensionalStorage, Dimensional}; +use crate::{storage::DimensionalStorage, Dimensional, LinearArrayStorage}; use num_traits::Num; use std::ops::{ - Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign, + Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign }; /// Implements indexing operations for Dimensional arrays. @@ -212,6 +212,44 @@ where } } +///Implements matrix multiplication for 2-Dimensional arrays +impl Dimensional +where + S: DimensionalStorage, +{ + + pub fn matmul(&self, rhs: &Self) -> Dimensional { + assert_eq!( + self.shape()[1], + rhs.shape()[0], + "Requires matrices be of the shapes (MxN) % (NxK). Interior dimensions do not match." + ); + let m = self.shape()[0]; + let n = self.shape()[1]; + let k = rhs.shape()[1]; + + let shape = [m, k]; + + let mut retval: Dimensional = Dimensional::zeros(shape); + + for i in 0..m { + for j in 0..k { + let sum: T = (0..n).map(|x| { + let raveled = Dimensional::, 2>::ravel_index(&[i,x], &self.shape()); + let raveled_rhs = Dimensional::, 2>::ravel_index(&[x,j], &rhs.shape()); + let a = self.as_slice()[raveled]; + let b = rhs.as_slice()[raveled_rhs]; + a*b + }).sum(); + retval[[i,j]] = sum; + } + } + retval + + } +} + + // Assignment operations /// Implements scalar addition assignment for Dimensional arrays. @@ -490,6 +528,9 @@ mod tests { m3 *= &m2; assert_eq!(m3, matrix![[5, 12], [21, 32]]); + assert_eq!(m1.matmul(&m2), matrix![[19, 22],[43, 50]]); + + // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } From c5d6f3ab08108ae9a29fd6f1bf65d26adf576452 Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 14:23:37 -0700 Subject: [PATCH 3/8] [wip] zoomies --- src/operators.rs | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/operators.rs b/src/operators.rs index 683e28c..5bdc777 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -228,24 +228,26 @@ where let n = self.shape()[1]; let k = rhs.shape()[1]; + // given combination of dimensions, and the fact that current built in + // iterators and mappings only iterate pairwise with identical indices, + // something more custom is needed. naive algorithm with for looping + // given below let shape = [m, k]; - - let mut retval: Dimensional = Dimensional::zeros(shape); - - for i in 0..m { - for j in 0..k { - let sum: T = (0..n).map(|x| { - let raveled = Dimensional::, 2>::ravel_index(&[i,x], &self.shape()); - let raveled_rhs = Dimensional::, 2>::ravel_index(&[x,j], &rhs.shape()); - let a = self.as_slice()[raveled]; - let b = rhs.as_slice()[raveled_rhs]; - a*b - }).sum(); - retval[[i,j]] = sum; - } - } - retval - + let r: Vec = (0..m) + .flat_map(|i| { + (0..k).map(move |j| { + (0..n) + .map(|x| { + let raveled = Dimensional::, 2>::ravel_index(&[i, x], &self.shape()); + let raveled_rhs = Dimensional::, 2>::ravel_index(&[x, j], &rhs.shape()); + self.as_slice()[raveled] * rhs.as_slice()[raveled_rhs] + }) + .sum() + }) + }) + .collect(); + Dimensional::from_fn(shape, |[i, j]| r[k*i+j]) + } } @@ -529,7 +531,7 @@ mod tests { assert_eq!(m3, matrix![[5, 12], [21, 32]]); assert_eq!(m1.matmul(&m2), matrix![[19, 22],[43, 50]]); - + // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } From 3093682fc80d2f96b7e7ee02d470ee59e5058b37 Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 14:29:26 -0700 Subject: [PATCH 4/8] format --- src/iterators.rs | 22 +++++++++++----------- src/operators.rs | 44 ++++++++++++++++++++++++-------------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/iterators.rs b/src/iterators.rs index 842b2e8..94bce35 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -65,7 +65,7 @@ where } /// An iterator over the elements of the transpose of a Dimensional array. -/// +/// /// This struct is created by the `iter_transpose` method on Dimensional /// to give access to row major elements of the transpose. Since we want /// transposition to be out of place, not much sense in making a mutable @@ -93,7 +93,9 @@ where } // Transpose the current index - let transposed_index: [usize; N] = self.current_index.iter() + let transposed_index: [usize; N] = self + .current_index + .iter() .enumerate() .map(|(i, _)| self.current_index[N - 1 - i]) .collect::>() @@ -247,9 +249,9 @@ where /// Returns an iterator over the eleents of the transposed array /// The iterator yields all items in the transposed array in row-major order. - /// + /// /// # Examples - /// + /// /// ``` /// use dimensionals::{Dimensional, vector, matrix, LinearArrayStorage}; /// let v = vector![1, 2, 3, 4, 5]; @@ -264,13 +266,13 @@ where /// assert_eq!(iter.next(), Some(&2)); /// assert_eq!(iter.next(), Some(&4)); /// assert_eq!(iter.next(), None); - /// + /// /// ``` - pub fn iter_transpose(&self) -> DimensionalTransposeIter{ + pub fn iter_transpose(&self) -> DimensionalTransposeIter { let len = self.len(); DimensionalTransposeIter { dimensional: self, - current_index: [0;N], + current_index: [0; N], remaining: len, } } @@ -306,7 +308,7 @@ where #[cfg(test)] mod tests { - use crate::{matrix, storage::LinearArrayStorage, Dimensional, vector}; + use crate::{matrix, storage::LinearArrayStorage, vector, Dimensional}; // ... (previous tests remain unchanged) @@ -321,9 +323,8 @@ mod tests { assert_eq!(iter.next(), None); } - #[test] - fn test_iter_transpose(){ + fn test_iter_transpose() { let v = vector![1, 2, 3, 4, 5]; let mut iter = v.iter_transpose(); assert_eq!(iter.next(), Some(&1)); @@ -336,6 +337,5 @@ mod tests { assert_eq!(iter.next(), Some(&2)); assert_eq!(iter.next(), Some(&4)); assert_eq!(iter.next(), None); - } } diff --git a/src/operators.rs b/src/operators.rs index 5bdc777..9417e3e 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -1,7 +1,7 @@ use crate::{storage::DimensionalStorage, Dimensional, LinearArrayStorage}; use num_traits::Num; use std::ops::{ - Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign + Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign, }; /// Implements indexing operations for Dimensional arrays. @@ -213,11 +213,10 @@ where } ///Implements matrix multiplication for 2-Dimensional arrays -impl Dimensional +impl Dimensional where - S: DimensionalStorage, + S: DimensionalStorage, { - pub fn matmul(&self, rhs: &Self) -> Dimensional { assert_eq!( self.shape()[1], @@ -234,24 +233,30 @@ where // given below let shape = [m, k]; let r: Vec = (0..m) - .flat_map(|i| { - (0..k).map(move |j| { - (0..n) - .map(|x| { - let raveled = Dimensional::, 2>::ravel_index(&[i, x], &self.shape()); - let raveled_rhs = Dimensional::, 2>::ravel_index(&[x, j], &rhs.shape()); - self.as_slice()[raveled] * rhs.as_slice()[raveled_rhs] - }) - .sum() + .flat_map(|i| { + (0..k).map(move |j| { + (0..n) + .map(|x| { + let raveled = + Dimensional::, 2>::ravel_index( + &[i, x], + &self.shape(), + ); + let raveled_rhs = + Dimensional::, 2>::ravel_index( + &[x, j], + &rhs.shape(), + ); + self.as_slice()[raveled] * rhs.as_slice()[raveled_rhs] + }) + .sum() + }) }) - }) - .collect(); - Dimensional::from_fn(shape, |[i, j]| r[k*i+j]) - + .collect(); + Dimensional::from_fn(shape, |[i, j]| r[k * i + j]) } } - // Assignment operations /// Implements scalar addition assignment for Dimensional arrays. @@ -530,8 +535,7 @@ mod tests { m3 *= &m2; assert_eq!(m3, matrix![[5, 12], [21, 32]]); - assert_eq!(m1.matmul(&m2), matrix![[19, 22],[43, 50]]); - + assert_eq!(m1.matmul(&m2), matrix![[19, 22], [43, 50]]); // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } From 47d11f819312ea8b646ae8739eaf6a885c04edaf Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 15:04:36 -0700 Subject: [PATCH 5/8] transposition --- src/operators.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/operators.rs b/src/operators.rs index 9417e3e..6773282 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -212,6 +212,31 @@ where } } +impl Dimensional +where + S: DimensionalStorage, +{ + pub fn transpose(&self) -> Dimensional { + let r: Vec = self + .iter_transpose() + .enumerate() + .map(|(_, val)| *val) + .collect(); + let new_shape: [usize; N] = self + .shape() + .iter() + .rev() + .copied() + .collect::>() + .try_into() + .unwrap(); + + Dimensional::from_fn(new_shape, |idxs: [usize; N]| { + r[Dimensional::, N>::ravel_index(&idxs, &new_shape)] + }) + } +} + ///Implements matrix multiplication for 2-Dimensional arrays impl Dimensional where @@ -537,6 +562,8 @@ mod tests { assert_eq!(m1.matmul(&m2), matrix![[19, 22], [43, 50]]); + assert_eq!(m1.transpose(), matrix![[1, 3], [2, 4]]) + // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } From 2eb2fc581106cc0101380ff9f254342dbf4a6dc5 Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 15:30:09 -0700 Subject: [PATCH 6/8] trace --- src/operators.rs | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/src/operators.rs b/src/operators.rs index 6773282..c0c7d47 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -211,7 +211,22 @@ where self.zip_map(rhs, |a, b| a / b) } } - +impl Dimensional +where + S: DimensionalStorage, +{ + pub fn trace(&self) -> T { + (0..*self.shape().iter().min().unwrap()) + .map(|i| { + let idxs = [i; N]; + self.as_slice()[Dimensional::, N>::ravel_index( + &idxs, + &self.shape(), + )] + }) + .sum() + } +} impl Dimensional where S: DimensionalStorage, @@ -246,7 +261,7 @@ where assert_eq!( self.shape()[1], rhs.shape()[0], - "Requires matrices be of the shapes (MxN) % (NxK). Interior dimensions do not match." + "Interior dimensions must match for matrix multiplication" ); let m = self.shape()[0]; let n = self.shape()[1]; @@ -560,13 +575,31 @@ mod tests { m3 *= &m2; assert_eq!(m3, matrix![[5, 12], [21, 32]]); + // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division + } + + #[test] + fn test_matrix_specific_operations() { + let m1 = matrix![[1, 2], [3, 4]]; + let m2 = matrix![[5, 6], [7, 8]]; + let m3 = matrix![[1, 2, 3], [4, 5 ,6], [7, 8, 9], [10, 11, 12]]; + assert_eq!(m1.matmul(&m2), matrix![[19, 22], [43, 50]]); - assert_eq!(m1.transpose(), matrix![[1, 3], [2, 4]]) + assert_eq!(m1.transpose(), matrix![[1, 3], [2, 4]]); + + assert_eq!(m1.trace(), 5); + assert_eq!(m3.trace(), 15); - // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } + #[test] + #[should_panic(expected = "Interior dimensions must match for matrix multiplication")] + fn test_mismatched_matrix_mult(){ + let m1 = matrix![[1, 2], [3, 4]]; + let m3 = matrix![[1, 2, 3], [4, 5 ,6], [7, 8, 9], [10, 11, 12]]; + let _ = m1.matmul(&m3); + } #[test] fn test_mixed_dimensional_operations() { let v = vector![1, 2, 3]; From 8205e44d54eeb756cf18190a940f1d4f96ec7991 Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 16:50:06 -0700 Subject: [PATCH 7/8] much cleaner looks --- src/operators.rs | 139 ++++++++++++++++++++++++++--------------------- 1 file changed, 76 insertions(+), 63 deletions(-) diff --git a/src/operators.rs b/src/operators.rs index c0c7d47..41851eb 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -211,44 +211,74 @@ where self.zip_map(rhs, |a, b| a / b) } } + +impl Dimensional +where + S: DimensionalStorage, +{ + pub fn block(&self, idxs: [usize; N], szs: [usize; N]) -> Dimensional { + for (&i, &j) in self.shape().to_vec().iter().zip(szs.to_vec().iter()) { + assert!( + i >= j, + "Requested block sizes cannot be bigger than total number of indices {} {}", i, j + ); + } + + let mut retval: Dimensional = Dimensional::zeros(szs); + + let total_elements: usize = szs.iter().product(); + + let mut indices = vec![0; N]; + let mut block_index: usize = 0; + + while block_index < total_elements { + let mut mat_indices = idxs.to_vec(); + for i in 0..N { + mat_indices[i] += indices[i]; + } + + let value = self.as_slice()[Dimensional::, N>::ravel_index( + &mat_indices + .into_iter() + .collect::>() + .try_into() + .unwrap(), + &szs, + )]; + retval + [Dimensional::, N>::unravel_index(block_index, &szs)] = + value; + + block_index += 1; + for i in (0..N).rev() { + indices[i] += 1; + if indices[i] < szs[i] { + break; + } else { + indices[i] = 0; + } + } + } + retval + } +} + impl Dimensional where S: DimensionalStorage, { pub fn trace(&self) -> T { (0..*self.shape().iter().min().unwrap()) - .map(|i| { - let idxs = [i; N]; - self.as_slice()[Dimensional::, N>::ravel_index( - &idxs, - &self.shape(), - )] - }) - .sum() + .fold(T::zero(), |sum, i| sum + self[[i;N]]) } } -impl Dimensional +impl Dimensional where - S: DimensionalStorage, + S: DimensionalStorage, { - pub fn transpose(&self) -> Dimensional { - let r: Vec = self - .iter_transpose() - .enumerate() - .map(|(_, val)| *val) - .collect(); - let new_shape: [usize; N] = self - .shape() - .iter() - .rev() - .copied() - .collect::>() - .try_into() - .unwrap(); - - Dimensional::from_fn(new_shape, |idxs: [usize; N]| { - r[Dimensional::, N>::ravel_index(&idxs, &new_shape)] - }) + pub fn transpose(&self) -> Dimensional { + let (rows, cols) = (self.shape()[1], self.shape()[0]); + Self::from_fn([rows, cols], |[i, j]| self[[j, i]]) } } @@ -263,37 +293,11 @@ where rhs.shape()[0], "Interior dimensions must match for matrix multiplication" ); - let m = self.shape()[0]; - let n = self.shape()[1]; - let k = rhs.shape()[1]; - - // given combination of dimensions, and the fact that current built in - // iterators and mappings only iterate pairwise with identical indices, - // something more custom is needed. naive algorithm with for looping - // given below - let shape = [m, k]; - let r: Vec = (0..m) - .flat_map(|i| { - (0..k).map(move |j| { - (0..n) - .map(|x| { - let raveled = - Dimensional::, 2>::ravel_index( - &[i, x], - &self.shape(), - ); - let raveled_rhs = - Dimensional::, 2>::ravel_index( - &[x, j], - &rhs.shape(), - ); - self.as_slice()[raveled] * rhs.as_slice()[raveled_rhs] - }) - .sum() - }) - }) - .collect(); - Dimensional::from_fn(shape, |[i, j]| r[k * i + j]) + let (rows, cols) = (self.shape()[0], rhs.shape()[1]); + + Self::from_fn([rows, cols], |[i, j]| { + (0..self.shape()[1]).fold(T::zero(), |sum, k| sum + self[[i, k]] * rhs[[k, j]]) + }) } } @@ -578,11 +582,21 @@ mod tests { // Note: We don't test m3 /= m2 here because it would result in a matrix of zeros due to integer division } + #[test] + fn test_matrix_blocking(){ + let m1 = matrix![[1, 2], [3, 4]]; + let m2 = matrix![[1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12]]; + + assert_eq!(m1.block([0, 0], [1, 2]), matrix![[1, 2]]); + assert_eq!(m1.block([0,0], m1.shape()), m1); + assert_eq!(m2.block([0, 1], [2, 4]), matrix![[2, 3, 4, 5],[8, 9, 10, 11]]); + } #[test] fn test_matrix_specific_operations() { let m1 = matrix![[1, 2], [3, 4]]; let m2 = matrix![[5, 6], [7, 8]]; - let m3 = matrix![[1, 2, 3], [4, 5 ,6], [7, 8, 9], [10, 11, 12]]; + let m3 = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]; assert_eq!(m1.matmul(&m2), matrix![[19, 22], [43, 50]]); @@ -590,14 +604,13 @@ mod tests { assert_eq!(m1.trace(), 5); assert_eq!(m3.trace(), 15); - } #[test] #[should_panic(expected = "Interior dimensions must match for matrix multiplication")] - fn test_mismatched_matrix_mult(){ + fn test_mismatched_matrix_mult() { let m1 = matrix![[1, 2], [3, 4]]; - let m3 = matrix![[1, 2, 3], [4, 5 ,6], [7, 8, 9], [10, 11, 12]]; + let m3 = matrix![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]; let _ = m1.matmul(&m3); } #[test] From 6abbe3b06981284650cfd12adf3c7a2d9d07c04a Mon Sep 17 00:00:00 2001 From: Tristan Britt Date: Wed, 26 Jun 2024 17:10:05 -0700 Subject: [PATCH 8/8] fixed blocking --- src/operators.rs | 46 ++++++++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/src/operators.rs b/src/operators.rs index 41851eb..7d5a8f6 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -1,4 +1,4 @@ -use crate::{storage::DimensionalStorage, Dimensional, LinearArrayStorage}; +use crate::{storage::DimensionalStorage, Dimensional}; use num_traits::Num; use std::ops::{ Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign, @@ -220,7 +220,9 @@ where for (&i, &j) in self.shape().to_vec().iter().zip(szs.to_vec().iter()) { assert!( i >= j, - "Requested block sizes cannot be bigger than total number of indices {} {}", i, j + "Requested block sizes cannot be bigger than total number of indices {} {}", + i, + j ); } @@ -228,34 +230,24 @@ where let total_elements: usize = szs.iter().product(); - let mut indices = vec![0; N]; + let mut target_index = [0; N]; let mut block_index: usize = 0; while block_index < total_elements { - let mut mat_indices = idxs.to_vec(); + let mut source_index = idxs; for i in 0..N { - mat_indices[i] += indices[i]; + source_index[i] += target_index[i]; } - let value = self.as_slice()[Dimensional::, N>::ravel_index( - &mat_indices - .into_iter() - .collect::>() - .try_into() - .unwrap(), - &szs, - )]; - retval - [Dimensional::, N>::unravel_index(block_index, &szs)] = - value; + retval[target_index] = self[source_index]; block_index += 1; for i in (0..N).rev() { - indices[i] += 1; - if indices[i] < szs[i] { + target_index[i] += 1; + if target_index[i] < szs[i] { break; } else { - indices[i] = 0; + target_index[i] = 0; } } } @@ -268,8 +260,7 @@ where S: DimensionalStorage, { pub fn trace(&self) -> T { - (0..*self.shape().iter().min().unwrap()) - .fold(T::zero(), |sum, i| sum + self[[i;N]]) + (0..*self.shape().iter().min().unwrap()).fold(T::zero(), |sum, i| sum + self[[i; N]]) } } impl Dimensional @@ -583,14 +574,17 @@ mod tests { } #[test] - fn test_matrix_blocking(){ + fn test_matrix_blocking() { let m1 = matrix![[1, 2], [3, 4]]; - let m2 = matrix![[1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12]]; + let m2 = matrix![[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]; assert_eq!(m1.block([0, 0], [1, 2]), matrix![[1, 2]]); - assert_eq!(m1.block([0,0], m1.shape()), m1); - assert_eq!(m2.block([0, 1], [2, 4]), matrix![[2, 3, 4, 5],[8, 9, 10, 11]]); + assert_eq!(m1.block([0, 0], m1.shape()), m1); + assert_eq!( + m2.block([0, 1], [2, 4]), + matrix![[2, 3, 4, 5], [8, 9, 10, 11]] + ); + assert_eq!(m2.block([1, 3], [1, 1]), matrix![[10]]); } #[test] fn test_matrix_specific_operations() {