From 1d1189887ddaf2b970bc6e53d3250a5e05dcdb87 Mon Sep 17 00:00:00 2001 From: Cobord Date: Wed, 19 Nov 2025 19:01:01 -0500 Subject: [PATCH 01/13] do many steps of SimpleOptimizer so example of optimization_demo gets actually closer to the target, multi_constraint example gave false positives for is_in_manifold --- examples/multi_constraints.rs | 98 ++++++++++++++++++++++---------- examples/optimization_demo.rs | 12 +++- examples/riemannian_adam_demo.rs | 2 +- src/manifolds.rs | 78 +++++++++++++------------ src/manifolds/sphere.rs | 11 +--- src/manifolds/steifiel.rs | 11 ++-- src/optimizers.rs | 55 +++++++++++++++++- src/optimizers/multiple.rs | 31 +++++----- 8 files changed, 199 insertions(+), 99 deletions(-) diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index bad4673..e1e0d4d 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -1,10 +1,12 @@ +use burn::module::Module; use burn::module::ModuleVisitor; -use burn::nn::LinearConfig; use burn::nn::Linear; -use burn::module::Module; +use burn::nn::LinearConfig; use manopt_rs::manifolds::Constrained; -use manopt_rs::optimizers::multiple::{MultiManifoldOptimizer, MultiManifoldOptimizerConfig, ManifoldOptimizable}; +use manopt_rs::optimizers::multiple::{ + ManifoldOptimizable, MultiManifoldOptimizer, MultiManifoldOptimizerConfig, +}; use manopt_rs::prelude::*; // Example: User-defined custom manifold @@ -33,7 +35,11 @@ impl Manifold for CustomSphereManifold { new_point / norm } - fn inner(_point: Tensor, u: Tensor, v: Tensor) -> Tensor { + fn inner( + _point: Tensor, + u: Tensor, + v: Tensor, + ) -> Tensor { u * v } @@ -43,9 +49,10 @@ impl Manifold for CustomSphereManifold { point / norm } - fn is_in_manifold(_point: Tensor) -> bool { - // For now, just return true - in a real implementation you'd check the constraint - true + fn is_in_manifold(point: Tensor) -> bool { + let r_squared = point.clone().powf_scalar(2.0).sum(); + let one = Tensor::::from_floats([1.0], &<::Device>::default()); + r_squared.all_close(one, None, None) } } @@ -53,7 +60,7 @@ impl Manifold for CustomSphereManifold { pub struct TestModel { // Euclidean constrained linear layer linear_euclidean: Constrained, Euclidean>, - // Custom sphere constrained linear layer + // Custom sphere constrained linear layer linear_sphere: Constrained, CustomSphereManifold>, // Regular unconstrained linear layer linear_regular: Linear, @@ -125,7 +132,7 @@ impl TestModel { let linear_euclidean = LinearConfig::new(10, 5).init(device); let linear_sphere = LinearConfig::new(5, 3).init(device); let linear_regular = LinearConfig::new(3, 1).init(device); - + Self { linear_euclidean: Constrained::new(linear_euclidean), linear_sphere: Constrained::new(linear_sphere), @@ -138,12 +145,26 @@ struct ManifoldAwareVisitor; impl ModuleVisitor for ManifoldAwareVisitor { fn visit_float(&mut self, id: burn::module::ParamId, tensor: &Tensor) { - println!("Visiting parameter: {:?} with shape: {:?}", id, tensor.dims()); + println!( + "Visiting parameter: {:?} with shape: {:?}", + id, + tensor.dims() + ); } - fn visit_int(&mut self, _id: burn::module::ParamId, _tensor: &Tensor) {} + fn visit_int( + &mut self, + _id: burn::module::ParamId, + _tensor: &Tensor, + ) { + } - fn visit_bool(&mut self, _id: burn::module::ParamId, _tensor: &Tensor) {} + fn visit_bool( + &mut self, + _id: burn::module::ParamId, + _tensor: &Tensor, + ) { + } } fn main() { @@ -151,53 +172,68 @@ fn main() { type AutoDiffBackend = burn::backend::Autodiff; let device = Default::default(); - + 
// Create a model with mixed manifold constraints let model = TestModel::::new(&device); - + println!("=== Model Structure ==="); - println!("Euclidean layer manifold: {}", model.linear_euclidean.manifold_name::()); - println!("Sphere layer manifold: {}", model.linear_sphere.manifold_name::()); - + println!( + "Euclidean layer manifold: {}", + model.linear_euclidean.manifold_name::() + ); + println!( + "Sphere layer manifold: {}", + model.linear_sphere.manifold_name::() + ); + // Create multi-manifold optimizer let config = MultiManifoldOptimizerConfig::default(); let mut optimizer = MultiManifoldOptimizer::new(config); - + // Collect manifold information from the model optimizer.collect_manifolds(&model); - + // Register custom manifold for specific parameters (if needed) optimizer.register_manifold::("linear_sphere.weight".to_string()); - + println!("\n=== Manifold Information ==="); - println!("Euclidean info: {:?}", model.linear_euclidean.get_manifold_info()); + println!( + "Euclidean info: {:?}", + model.linear_euclidean.get_manifold_info() + ); println!("Sphere info: {:?}", model.linear_sphere.get_manifold_info()); - + // Example of applying constraints let constrained_model = optimizer.apply_constraints(model); - + // Visit the model to see parameter structure println!("\n=== Parameter Structure ==="); let mut visitor = ManifoldAwareVisitor; constrained_model.visit(&mut visitor); - + println!("\n=== Demonstrating Custom Manifold Operations ==="); - + // Show how the custom sphere manifold works let point = Tensor::::from_floats([3.0, 4.0, 0.0], &device); let vector = Tensor::::from_floats([1.0, 1.0, 1.0], &device); - + println!("Original point: {:?}", point.to_data()); println!("Original vector: {:?}", vector.to_data()); - + // Project point to sphere let projected_point = CustomSphereManifold::proj(point.clone()); println!("Point projected to sphere: {:?}", projected_point.to_data()); - + // Project vector to tangent space let projected_vector = CustomSphereManifold::project(projected_point.clone(), vector); - println!("Vector projected to tangent space: {:?}", projected_vector.to_data()); - + println!( + "Vector projected to tangent space: {:?}", + projected_vector.to_data() + ); + // Check if point is on manifold - println!("Is projected point on sphere? {}", CustomSphereManifold::is_in_manifold(projected_point)); + println!( + "Is projected point on sphere? 
{}", + CustomSphereManifold::is_in_manifold(projected_point) + ); } diff --git a/examples/optimization_demo.rs b/examples/optimization_demo.rs index 9d8c36c..0289202 100644 --- a/examples/optimization_demo.rs +++ b/examples/optimization_demo.rs @@ -1,5 +1,5 @@ use burn::optim::SimpleOptimizer; -use manopt_rs::prelude::*; +use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*}; fn main() { // Configure the optimizer @@ -39,9 +39,19 @@ fn main() { } } + println!("\nResult after 100:"); + println!("x = {}", x); + println!("Target = {}", target); + let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); + println!("Loss after 100 = {}", final_loss); + + // Perform optimization steps + (x, state) = optimizer.many_steps(|_| 1.0, 400, |x| (x - target.clone()) * 2.0, x, state); + println!("\nFinal result:"); println!("x = {}", x); println!("Target = {}", target); let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); println!("Final loss = {}", final_loss); + println!("State is set {}", state.is_some()); } diff --git a/examples/riemannian_adam_demo.rs b/examples/riemannian_adam_demo.rs index 6609ed7..8200f14 100644 --- a/examples/riemannian_adam_demo.rs +++ b/examples/riemannian_adam_demo.rs @@ -20,7 +20,7 @@ fn main() { let (new_tensor, state) = optimizer.step(1.0, tensor.clone(), grad, None); println!("Original tensor: {}", tensor); - println!("New tensor: {}", new_tensor); + println!("New tensor: {:.4}", new_tensor); println!("State initialized: {}", state.is_some()); println!("Riemannian Adam test completed successfully!"); diff --git a/src/manifolds.rs b/src/manifolds.rs index 1cc2802..a4ce5bd 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -9,7 +9,10 @@ use std::{fmt::Debug, marker::PhantomData}; use crate::prelude::*; pub mod steifiel; -use burn::{module::{AutodiffModule, ModuleDisplay}, tensor::backend::AutodiffBackend}; +use burn::{ + module::{AutodiffModule, ModuleDisplay}, + tensor::backend::AutodiffBackend, +}; pub use steifiel::SteifielsManifold; pub mod sphere; @@ -56,10 +59,7 @@ pub trait Manifold: Clone + Send + Sync { fn name() -> &'static str; fn project(point: Tensor, vector: Tensor) -> Tensor; - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor; + fn retract(point: Tensor, direction: Tensor) -> Tensor; /// Convert Euclidean gradient to Riemannian gradient fn egrad2rgrad(point: Tensor, grad: Tensor) -> Tensor { @@ -71,10 +71,7 @@ pub trait Manifold: Clone + Send + Sync { -> Tensor; /// Exponential map: move from point along tangent vector u with step size - fn expmap( - point: Tensor, - direction: Tensor, - ) -> Tensor { + fn expmap(point: Tensor, direction: Tensor) -> Tensor { Self::retract(point, direction) } @@ -122,10 +119,7 @@ impl Manifold for Euclidean { vector } - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor { + fn retract(point: Tensor, direction: Tensor) -> Tensor { point + direction } @@ -148,8 +142,8 @@ pub struct Constrained { _manifold: PhantomData, } -impl Module for Constrained -where +impl Module for Constrained +where M: Module, B: Backend, Man: Clone + Debug + Send, @@ -201,9 +195,8 @@ where } } - -impl AutodiffModule for Constrained -where +impl AutodiffModule for Constrained +where M: AutodiffModule, B: AutodiffBackend, Man: Clone + Debug + Send, @@ -216,7 +209,7 @@ where } impl burn::module::ModuleDisplayDefault for Constrained -where +where M: burn::module::ModuleDisplayDefault, Man: Clone + Debug + Send, { @@ -225,8 +218,8 @@ where } } -impl ModuleDisplay for Constrained -where 
+impl ModuleDisplay for Constrained +where M: ModuleDisplay, Man: Clone + Debug + Send, { @@ -242,19 +235,23 @@ impl Constrained { _manifold: PhantomData, } } - + /// Get a reference to the inner module pub fn inner(&self) -> &M { &self.module } - + /// Get a mutable reference to the inner module pub fn inner_mut(&mut self) -> &mut M { &mut self.module } - + /// Apply manifold projection to a tensor - requires explicit Backend type - pub fn project_tensor(&self, point: Tensor, vector: Tensor) -> Tensor + pub fn project_tensor( + &self, + point: Tensor, + vector: Tensor, + ) -> Tensor where B: Backend, M: Module, @@ -262,9 +259,13 @@ impl Constrained { { Man::project(point, vector) } - + /// Apply manifold retraction to a tensor - requires explicit Backend type - pub fn retract_tensor(&self, point: Tensor, direction: Tensor) -> Tensor + pub fn retract_tensor( + &self, + point: Tensor, + direction: Tensor, + ) -> Tensor where B: Backend, M: Module, @@ -272,9 +273,13 @@ impl Constrained { { Man::retract(point, direction) } - + /// Convert Euclidean gradient to Riemannian gradient - requires explicit Backend type - pub fn euclidean_to_riemannian(&self, point: Tensor, grad: Tensor) -> Tensor + pub fn euclidean_to_riemannian( + &self, + point: Tensor, + grad: Tensor, + ) -> Tensor where B: Backend, M: Module, @@ -282,7 +287,7 @@ impl Constrained { { Man::egrad2rgrad(point, grad) } - + /// Project point onto manifold - requires explicit Backend type pub fn project_to_manifold(&self, point: Tensor) -> Tensor where @@ -292,9 +297,9 @@ impl Constrained { { Man::proj(point) } - + /// Get the manifold name - pub fn manifold_name(&self) -> &'static str + pub fn manifold_name(&self) -> &'static str where B: Backend, Man: Manifold, @@ -306,11 +311,12 @@ impl Constrained { /// Trait for modules that have manifold constraints pub trait ConstrainedModule { /// Apply manifold constraints to all parameters in the module + #[must_use] fn apply_manifold_constraints(self) -> Self; - + /// Get information about the manifold constraints fn get_manifold_info(&self) -> std::collections::HashMap; - + /// Check if this module has manifold constraints fn has_manifold_constraints(&self) -> bool { true @@ -327,10 +333,10 @@ where fn apply_manifold_constraints(self) -> Self { self } - + fn get_manifold_info(&self) -> std::collections::HashMap { let mut info = std::collections::HashMap::new(); info.insert("manifold_type".to_string(), Man::name().to_string()); info } -} \ No newline at end of file +} diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index 17bfd28..4c091d2 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -13,18 +13,13 @@ impl Manifold for Sphere { "Sphere" } - fn project(_point: Tensor, vector: Tensor) -> Tensor - { + fn project(_point: Tensor, vector: Tensor) -> Tensor { // Y/||y| - vector.clone()/(vector.clone().transpose().matmul(vector)).sqrt() + vector.clone() / (vector.clone().transpose().matmul(vector)).sqrt() } - fn retract( - _point: Tensor, - _direction: Tensor, - ) -> Tensor { + fn retract(_point: Tensor, _direction: Tensor) -> Tensor { todo!("Implement retract for Sphere manifold") - } fn inner( diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 3c5be4c..569164b 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -17,7 +17,7 @@ impl Manifold for SteifielsManifold { } /// Project direction onto tangent space at point - /// For Stiefel manifold: P_X(Z) = Z - X(X^T Z + Z^T X)/2 + /// For Stiefel manifold: `P_X(Z) = Z - 
X(X^T Z + Z^T X)/2` fn project(point: Tensor, direction: Tensor) -> Tensor { let xtd = point.clone().transpose().matmul(direction.clone()); let dtx = direction.clone().transpose().matmul(point.clone()); @@ -25,16 +25,13 @@ impl Manifold for SteifielsManifold { direction - point.matmul(symmetric_part) } - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor { + fn retract(point: Tensor, direction: Tensor) -> Tensor { let s = point + direction; gram_schmidt(&s) } fn inner( - _point: Tensor< B, D>, + _point: Tensor, u: Tensor, v: Tensor, ) -> Tensor { @@ -309,7 +306,7 @@ mod test { let step = 0.1; let retracted = - SteifielsManifold::::retract(point.clone(), direction.clone()*step); + SteifielsManifold::::retract(point.clone(), direction.clone() * step); // Check that the result has orthonormal columns let q1 = retracted.clone().slice([0..3, 0..1]); diff --git a/src/optimizers.rs b/src/optimizers.rs index 4bcf6b4..d3ca525 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -2,15 +2,51 @@ //! //! This module provides Riemannian optimization algorithms that work on manifolds, //! extending classical optimization methods to handle geometric constraints. +use crate::prelude::*; use burn::module::AutodiffModule; use burn::optim::{adaptor::OptimizerAdaptor, LrDecayState, SimpleOptimizer}; use burn::record::Record; use burn::tensor::backend::AutodiffBackend; use burn::LearningRate; use std::marker::PhantomData; -use crate::prelude::*; pub mod multiple; +pub trait LessSimpleOptimizer: SimpleOptimizer { + fn many_steps( + &self, + lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + grad_function: impl FnMut(Tensor) -> Tensor, + tensor: Tensor, + state: Option>, + ) -> (Tensor, Option>); +} + +impl> LessSimpleOptimizer for T { + #[inline] + fn many_steps( + &self, + mut lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + mut grad_function: impl FnMut(Tensor) -> Tensor, + mut tensor: Tensor, + mut state: Option>, + ) -> (Tensor, Option>) { + // Perform optimization steps + for step in 0..num_steps { + // Compute gradient at tensor + let cur_grad = grad_function(tensor.clone()); + // The current learning rate for this step + let cur_lr = lr_function(step); + // Perform optimizer step + let (new_x, new_state) = self.step(cur_lr, tensor.clone(), cur_grad, state); + tensor = new_x; + state = new_state; + } + (tensor, state) + } +} + #[derive(Debug)] pub struct ManifoldRGDConfig { _manifold: PhantomData, @@ -77,6 +113,7 @@ where _state: Self::State, _device: &::Device, ) -> Self::State { + #[allow(clippy::used_underscore_binding)] _state } } @@ -86,6 +123,7 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn init>( &self, ) -> OptimizerAdaptor, Mod, Back> @@ -154,40 +192,48 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn new() -> Self { Self::default() } + #[must_use] pub fn with_lr(mut self, lr: f64) -> Self { self.lr = lr; self } + #[must_use] pub fn with_beta1(mut self, beta1: f64) -> Self { self.beta1 = beta1; self } + #[must_use] pub fn with_beta2(mut self, beta2: f64) -> Self { self.beta2 = beta2; self } + #[must_use] pub fn with_eps(mut self, eps: f64) -> Self { self.eps = eps; self } + #[must_use] pub fn with_weight_decay(mut self, weight_decay: f64) -> Self { self.weight_decay = weight_decay; self } + #[must_use] pub fn with_amsgrad(mut self, amsgrad: bool) -> Self { self.amsgrad = amsgrad; self } + #[must_use] pub fn with_stabilize(mut self, stabilize: Option) -> Self { self.stabilize = stabilize; self @@ -205,6 +251,7 @@ where M: 
Manifold, B: Backend, { + #[must_use] pub fn new(config: RiemannianAdamConfig) -> Self { Self { config } } @@ -269,7 +316,8 @@ where state.exp_avg.clone() * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1); let inner_product = M::inner(tensor.clone(), rgrad.clone(), rgrad.clone()); - state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2 + inner_product * (1.0 - self.config.beta2); + state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2 + + inner_product * (1.0 - self.config.beta2); // Compute denominator let denom = if self.config.amsgrad { @@ -282,7 +330,9 @@ where }; // Bias correction + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] let bias_correction1 = 1.0 - self.config.beta1.powi(state.step as i32); + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] let bias_correction2 = 1.0 - self.config.beta2.powi(state.step as i32); let step_size = learning_rate * bias_correction2.sqrt() / bias_correction1; @@ -319,6 +369,7 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn init>( &self, ) -> OptimizerAdaptor, Mod, Back> diff --git a/src/optimizers/multiple.rs b/src/optimizers/multiple.rs index 1c6c887..78cb9fd 100644 --- a/src/optimizers/multiple.rs +++ b/src/optimizers/multiple.rs @@ -2,17 +2,17 @@ //! //! This module provides optimizers that can handle multiple manifold types within //! a single optimization step, allowing complex models with heterogeneous constraints. -//! +//! //! Users can implement their own manifolds and use them with the multi-manifold optimizer. use std::collections::HashMap; -use std::marker::PhantomData; use std::fmt::Debug; +use std::marker::PhantomData; use burn::module::Module; -use crate::prelude::*; use crate::manifolds::Constrained; +use crate::prelude::*; /// Multi-manifold optimizer configuration #[derive(Debug, Clone)] @@ -39,11 +39,13 @@ impl Default for MultiManifoldOptimizerConfig { /// Multi-manifold optimizer that can handle different manifold types #[derive(Debug)] pub struct MultiManifoldOptimizer { + #[allow(unused)] config: MultiManifoldOptimizerConfig, _backend: PhantomData, } impl MultiManifoldOptimizer { + #[must_use] pub fn new(config: MultiManifoldOptimizerConfig) -> Self { Self { config, @@ -73,8 +75,9 @@ impl MultiManifoldOptimizer { /// Extension trait for modules with manifold constraints pub trait ManifoldOptimizable: Module { /// Apply manifold constraints to the module + #[must_use] fn apply_manifold_constraints(self) -> Self; - + /// Get information about manifold constraints fn get_manifold_info(&self) -> HashMap; } @@ -90,7 +93,7 @@ where // Apply constraints to the inner module and wrap it back self } - + fn get_manifold_info(&self) -> HashMap { let mut info = HashMap::new(); info.insert("manifold_type".to_string(), Man::name().to_string()); @@ -106,12 +109,11 @@ mod tests { type TestBackend = NdArray; - #[test] fn test_multi_manifold_optimizer() { let config = MultiManifoldOptimizerConfig::default(); let optimizer = MultiManifoldOptimizer::::new(config); - + // Test basic construction assert_eq!(optimizer.config.learning_rate, 1e-3); } @@ -121,7 +123,7 @@ mod tests { let device = Default::default(); let linear = LinearConfig::new(2, 2).init::(&device); let constrained_linear = Constrained::<_, Euclidean>::new(linear); - + let info = constrained_linear.get_manifold_info(); assert_eq!(info.get("manifold_type"), Some(&"Euclidean".to_string())); } @@ -130,15 +132,18 @@ mod tests { fn test_apply_constraints() { let config = 
MultiManifoldOptimizerConfig::default(); let optimizer = MultiManifoldOptimizer::::new(config); - + let device = Default::default(); let linear = LinearConfig::new(2, 2).init::(&device); let constrained_linear = Constrained::<_, Euclidean>::new(linear); - + // Test applying constraints let result = optimizer.apply_constraints(constrained_linear); - + // Should return the same module since we have a simplified implementation - assert_eq!(result.get_manifold_info().get("manifold_type"), Some(&"Euclidean".to_string())); + assert_eq!( + result.get_manifold_info().get("manifold_type"), + Some(&"Euclidean".to_string()) + ); } -} \ No newline at end of file +} From 6862a17269dce1612989b4d5501471b85d17dd47 Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 20 Nov 2025 13:22:09 -0500 Subject: [PATCH 02/13] unwrap -> expect (particularly relevant for today), test that can do grad_remove as well when only using that gradient information once --- src/manifolds/steifiel.rs | 39 ++++++++++++++++++++++++++++++++++++++- src/optimizers.rs | 16 +++++++++++----- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 569164b..b9350f6 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -403,7 +403,44 @@ mod test { .matmul(x.clone()) .sum(); let grads = loss.backward(); - let x_grad = x.grad(&grads).unwrap(); + let x_grad = x + .grad(&grads) + .expect("The gradients do exist we just did loss.backwards()"); + // Convert gradient to autodiff backend and ensure independent tensor + let x_grad_data = x_grad.to_data(); + let x_grad_ad = Tensor::::from_data(x_grad_data, &x.device()); + // Clone x to ensure independent tensor for optimizer + let x_clone = x.clone(); + let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None); + x = new_x.detach().require_grad(); + println!("Loss: {}", loss); + } + println!("Optimised tensor: {}", x); + } + + #[test] + fn test_optimiser_remove() { + let optimiser = ManifoldRGD::, TestBackend>::default(); + + let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]); + + let mut x = Tensor::::random( + [3, 3], + burn::tensor::Distribution::Normal(1., 1.), + &a.device(), + ) + .require_grad(); + for _i in 0..100 { + let loss = x + .clone() + .transpose() + .matmul(a.clone()) + .matmul(x.clone()) + .sum(); + let mut grads = loss.backward(); + let x_grad = x + .grad_remove(&mut grads) + .expect("The gradients do exist we just did loss.backwards()"); // Convert gradient to autodiff backend and ensure independent tensor let x_grad_data = x_grad.to_data(); let x_grad_ad = Tensor::::from_data(x_grad_data, &x.device()); diff --git a/src/optimizers.rs b/src/optimizers.rs index d3ca525..02b84ba 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -321,7 +321,9 @@ where // Compute denominator let denom = if self.config.amsgrad { - let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().unwrap(); + let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().expect( + "On an initial None state, having config.amsgrad be True makes this maximum field set to 0. 
\ + If there was an input state then it will be present because of earlier steps"); let new_max = Tensor::max_pair(max_exp_avg_sq.clone(), state.exp_avg_sq.clone()); state.max_exp_avg_sq = Some(new_max.clone()); new_max.sqrt() + self.config.eps @@ -470,10 +472,11 @@ mod tests { // Check that AMSGrad state is initialized assert!(state.is_some()); - let state = state.unwrap(); + let state = + state.expect("RiemannianAdam optimizer always gives back an initialized state on step"); assert!( state.max_exp_avg_sq.is_some(), - "AMSGrad should initialize max_exp_avg_sq" + "AMSGrad should initialize max_exp_avg_sq. See the explanation around the compute denominator part of step" ); } @@ -512,13 +515,16 @@ mod tests { // First step let (tensor1, state1) = optimizer.step(1.0, tensor, grad.clone(), None); assert!(state1.is_some()); - let state1 = state1.unwrap(); + let state1 = state1.expect( + "RiemannianAdam optimizer always gives back an initialized state on step even with None initial state"); assert_eq!(state1.step, 1); // Second step with state let (_, state2) = optimizer.step(1.0, tensor1, grad, Some(state1)); assert!(state2.is_some()); - let state2 = state2.unwrap(); + let state2 = state2.expect( + "There was an input state so RiemannianAdam optimizer's step modifies that and returns it" + ); assert_eq!(state2.step, 2); } } From d023aba4bbbdb736d3e48835d531775ef07f989e Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 20 Nov 2025 13:36:47 -0500 Subject: [PATCH 03/13] if B is not AutodiffBackend, then no effect anyway --- src/optimizers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimizers.rs b/src/optimizers.rs index 02b84ba..f1e518b 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -40,7 +40,7 @@ impl> LessSimpleOptimizer for T { let cur_lr = lr_function(step); // Perform optimizer step let (new_x, new_state) = self.step(cur_lr, tensor.clone(), cur_grad, state); - tensor = new_x; + tensor = new_x.detach().require_grad(); state = new_state; } (tensor, state) From 71d9935a229f411375431cd1fc871fe8391027ad Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 20 Nov 2025 13:45:29 -0500 Subject: [PATCH 04/13] a place where that detach and require_grad has an effect and is needed to make the test pass --- src/manifolds/steifiel.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index b9350f6..d721776 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -67,6 +67,8 @@ fn gram_schmidt(v: &Tensor) -> Tensor { #[cfg(test)] mod test { + use crate::optimizers::LessSimpleOptimizer; + use super::*; use burn::{ backend::{Autodiff, NdArray}, @@ -453,6 +455,39 @@ mod test { println!("Optimised tensor: {}", x); } + #[test] + fn test_optimiser_many() { + let optimiser = ManifoldRGD::, TestBackend>::default(); + + let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]); + + let mut x = Tensor::::random( + [3, 3], + burn::tensor::Distribution::Normal(1., 1.), + &a.device(), + ) + .require_grad(); + + fn grad_fn( + x: Tensor, 2>, + a: Tensor, 2>, + ) -> Tensor, 2> { + let loss = x.clone().transpose().matmul(a).matmul(x.clone()).sum(); + let mut grads = loss.backward(); + let x_grad = x + .grad_remove(&mut grads) + .expect("The gradients do exist we just did loss.backwards()"); + // Convert gradient to autodiff backend and ensure independent tensor + let x_grad_ad = Tensor::::from_data(x_grad.to_data(), &x.device()); + x_grad_ad + } + + let mut 
state = None; + (x, state) = optimiser.many_steps(|_| 0.1, 100, |x| grad_fn(x, a.clone()), x, state); + assert!(state.is_none()); + println!("Optimised tensor: {}", x); + } + #[test] fn test_simple_optimizer_step() { let optimiser = ManifoldRGD::, TestBackend>::default(); From 520b9321bc91958bfec493ffab1cd1686103a96c Mon Sep 17 00:00:00 2001 From: Cobord Date: Tue, 25 Nov 2025 02:18:43 -0500 Subject: [PATCH 05/13] many steps and constrained modules moved to own files since they are not strictly tied of other content of the files they were in, expand tests and examples --- examples/multi_constraints.rs | 16 ++- examples/optimization_demo.rs | 27 +++- examples/riemannian_adam_demo.rs | 14 +- src/constrained_module.rs | 217 ++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/manifolds.rs | 221 ++----------------------------- src/manifolds/sphere.rs | 1 - src/manifolds/steifiel.rs | 32 ++++- src/optimizers.rs | 56 ++------ src/optimizers/many_steps.rs | 50 +++++++ src/optimizers/multiple.rs | 2 +- 11 files changed, 372 insertions(+), 265 deletions(-) create mode 100644 src/constrained_module.rs create mode 100644 src/optimizers/many_steps.rs diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index e1e0d4d..5dda54b 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -3,7 +3,7 @@ use burn::module::ModuleVisitor; use burn::nn::Linear; use burn::nn::LinearConfig; -use manopt_rs::manifolds::Constrained; +use manopt_rs::constrained_module::Constrained; use manopt_rs::optimizers::multiple::{ ManifoldOptimizable, MultiManifoldOptimizer, MultiManifoldOptimizerConfig, }; @@ -54,6 +54,12 @@ impl Manifold for CustomSphereManifold { let one = Tensor::::from_floats([1.0], &<::Device>::default()); r_squared.all_close(one, None, None) } + + fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { + let dot_product = (point * vector).sum(); + let zero = Tensor::::from_floats([0.0], &<::Device>::default()); + dot_product.all_close(zero, None, Some(1e-6)) + } } #[derive(Debug, Clone)] @@ -234,6 +240,12 @@ fn main() { // Check if point is on manifold println!( "Is projected point on sphere? {}", - CustomSphereManifold::is_in_manifold(projected_point) + CustomSphereManifold::is_in_manifold(projected_point.clone()) + ); + + // Check if vector is tangent at point on manifold + println!( + "Is projected vector tangent to point on sphere? 
{}", + CustomSphereManifold::is_tangent_at(projected_point, projected_vector) ); } diff --git a/examples/optimization_demo.rs b/examples/optimization_demo.rs index 0289202..8bc7c27 100644 --- a/examples/optimization_demo.rs +++ b/examples/optimization_demo.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use burn::optim::SimpleOptimizer; use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*}; @@ -18,6 +20,8 @@ fn main() { Tensor::::from_floats([0.0, 0.0, 0.0], &Default::default()); let mut state = None; + let mut loss_decay: HashMap = HashMap::new(); + println!("Target: {}", target); println!("Initial x: {}", x); println!("\nOptimization steps:"); @@ -35,15 +39,21 @@ fn main() { // Print progress every 10 steps if step % 10 == 0 { let loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); - println!("Step {}: x = {}, loss = {}", step, x, loss); + let loss_scalar = loss.into_scalar(); + println!("Step {}: x = {}, loss = {:.5}", step, x, loss_scalar); + loss_decay.insert(step, loss_scalar); } } println!("\nResult after 100:"); println!("x = {}", x); println!("Target = {}", target); - let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); - println!("Loss after 100 = {}", final_loss); + let final_loss = (x.clone() - target.clone()) + .powf_scalar(2.0) + .sum() + .into_scalar(); + println!("Loss after 100 = {:.5}", final_loss); + loss_decay.insert(100, final_loss); // Perform optimization steps (x, state) = optimizer.many_steps(|_| 1.0, 400, |x| (x - target.clone()) * 2.0, x, state); @@ -51,7 +61,14 @@ fn main() { println!("\nFinal result:"); println!("x = {}", x); println!("Target = {}", target); - let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); - println!("Final loss = {}", final_loss); + let final_loss = (x.clone() - target.clone()) + .powf_scalar(2.0) + .sum() + .into_scalar(); + println!("Final loss = {:.5}", final_loss); println!("State is set {}", state.is_some()); + loss_decay.insert(500, final_loss); + let mut sorted_losses: Vec<(usize, f32)> = loss_decay.into_iter().collect(); + sorted_losses.sort_by_key(|z| z.0); + println!("The loss decayed as follows: {:?}", sorted_losses); } diff --git a/examples/riemannian_adam_demo.rs b/examples/riemannian_adam_demo.rs index 8200f14..9238f6d 100644 --- a/examples/riemannian_adam_demo.rs +++ b/examples/riemannian_adam_demo.rs @@ -1,5 +1,5 @@ use burn::optim::SimpleOptimizer; -use manopt_rs::prelude::*; +use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*}; fn main() { println!("Testing Riemannian Adam optimizer..."); @@ -23,5 +23,17 @@ fn main() { println!("New tensor: {:.4}", new_tensor); println!("State initialized: {}", state.is_some()); + fn grad_function(_x: Tensor) -> Tensor { + Tensor::::ones([2, 2], &Default::default()) + } + // Perform more optimization steps + let (new_tensor, state) = + optimizer.many_steps(|_| 1.0, 9, grad_function, tensor.clone(), state); + println!("After even more steps Tensor: {:.4}", new_tensor); + println!( + "After even more steps State initialized: {}", + state.is_some() + ); + println!("Riemannian Adam test completed successfully!"); } diff --git a/src/constrained_module.rs b/src/constrained_module.rs new file mode 100644 index 0000000..01ec1f0 --- /dev/null +++ b/src/constrained_module.rs @@ -0,0 +1,217 @@ +//! Riemannian manifolds for constrained optimization. +//! +//! This module defines manifolds and their operations for Riemannian optimization. 
+ +use std::{fmt::Debug, marker::PhantomData}; + +use crate::prelude::*; + +use burn::{ + module::{AutodiffModule, ModuleDisplay}, + tensor::backend::AutodiffBackend, +}; + +#[derive(Clone, Debug)] +pub struct Constrained { + module: M, + _manifold: PhantomData, +} + +impl Module for Constrained +where + M: Module, + B: Backend, + Man: Clone + Debug + Send, +{ + type Record = M::Record; + + fn collect_devices(&self, devices: burn::module::Devices) -> burn::module::Devices { + self.module.collect_devices(devices) + } + + fn fork(self, device: &B::Device) -> Self { + let module = self.module.fork(device); + Self { + module, + _manifold: PhantomData, + } + } + + fn to_device(self, device: &B::Device) -> Self { + let module = self.module.to_device(device); + Self { + module, + _manifold: PhantomData, + } + } + + fn visit>(&self, visitor: &mut Visitor) { + self.module.visit(visitor); + } + + fn map>(self, mapper: &mut Mapper) -> Self { + let module = self.module.map(mapper); + Self { + module, + _manifold: PhantomData, + } + } + + fn load_record(self, record: Self::Record) -> Self { + let module = self.module.load_record(record); + Self { + module, + _manifold: PhantomData, + } + } + + fn into_record(self) -> Self::Record { + self.module.into_record() + } +} + +impl AutodiffModule for Constrained +where + M: AutodiffModule, + B: AutodiffBackend, + Man: Clone + Debug + Send, +{ + type InnerModule = M::InnerModule; + + fn valid(&self) -> Self::InnerModule { + self.module.valid() + } +} + +impl burn::module::ModuleDisplayDefault for Constrained +where + M: burn::module::ModuleDisplayDefault, + Man: Clone + Debug + Send, +{ + fn content(&self, content: burn::module::Content) -> Option { + self.module.content(content) + } +} + +impl ModuleDisplay for Constrained +where + M: ModuleDisplay, + Man: Clone + Debug + Send, +{ + fn format(&self, passed_settings: burn::module::DisplaySettings) -> String { + format!("Constrained<{}>", self.module.format(passed_settings)) + } +} + +impl Constrained { + pub fn new(module: M) -> Self { + Self { + module, + _manifold: PhantomData, + } + } + + /// Get a reference to the inner module + pub fn inner(&self) -> &M { + &self.module + } + + /// Get a mutable reference to the inner module + pub fn inner_mut(&mut self) -> &mut M { + &mut self.module + } + + /// Apply manifold projection to a tensor - requires explicit Backend type + pub fn project_tensor( + &self, + point: Tensor, + vector: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::project(point, vector) + } + + /// Apply manifold retraction to a tensor - requires explicit Backend type + pub fn retract_tensor( + &self, + point: Tensor, + direction: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::retract(point, direction) + } + + /// Convert Euclidean gradient to Riemannian gradient - requires explicit Backend type + pub fn euclidean_to_riemannian( + &self, + point: Tensor, + grad: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::egrad2rgrad(point, grad) + } + + /// Project point onto manifold - requires explicit Backend type + pub fn project_to_manifold(&self, point: Tensor) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::proj(point) + } + + /// Get the manifold name + pub fn manifold_name(&self) -> &'static str + where + B: Backend, + Man: Manifold, + { + Man::name() + } +} + +/// Trait for 
modules that have manifold constraints +pub trait ConstrainedModule { + /// Apply manifold constraints to all parameters in the module + #[must_use] + fn apply_manifold_constraints(self) -> Self; + + /// Get information about the manifold constraints + fn get_manifold_info(&self) -> std::collections::HashMap; + + /// Check if this module has manifold constraints + fn has_manifold_constraints(&self) -> bool { + true + } +} + +/// Blanket implementation for Constrained wrapper +impl ConstrainedModule for Constrained +where + M: Module, + B: Backend, + Man: Manifold + Clone + Debug + Send, +{ + fn apply_manifold_constraints(self) -> Self { + self + } + + fn get_manifold_info(&self) -> std::collections::HashMap { + let mut info = std::collections::HashMap::new(); + info.insert("manifold_type".to_string(), Man::name().to_string()); + info + } +} diff --git a/src/lib.rs b/src/lib.rs index 7bdbd5c..8dd69de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod constrained_module; pub mod manifolds; pub mod optimizers; pub mod prelude; diff --git a/src/manifolds.rs b/src/manifolds.rs index a4ce5bd..1cd4320 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -1,21 +1,18 @@ //! Riemannian manifolds for constrained optimization. //! -//! This module defines manifolds and their operations for Riemannian optimization. +//! This module defines manifolds and their operations. //! Each manifold implements geometric operations like projection, retraction, //! exponential maps, and parallel transport. -use std::{fmt::Debug, marker::PhantomData}; +use std::fmt::Debug; use crate::prelude::*; pub mod steifiel; -use burn::{ - module::{AutodiffModule, ModuleDisplay}, - tensor::backend::AutodiffBackend, -}; pub use steifiel::SteifielsManifold; pub mod sphere; +pub use sphere::Sphere; /// A Riemannian manifold defines the geometric structure for optimization. /// @@ -100,6 +97,13 @@ pub trait Manifold: Clone + Send + Sync { fn is_in_manifold(_point: Tensor) -> bool { false } + + /// Check if a vector is in the tangent space at point + /// given that point is in the manifold. + /// By default, this is not implemented and returns `false`. 
+ fn is_tangent_at(_point: Tensor, _vector: Tensor) -> bool { + false + } } /// Euclidean manifold - the simplest case where no projection is needed @@ -135,208 +139,3 @@ impl Manifold for Euclidean { true } } - -#[derive(Clone, Debug)] -pub struct Constrained { - module: M, - _manifold: PhantomData, -} - -impl Module for Constrained -where - M: Module, - B: Backend, - Man: Clone + Debug + Send, -{ - type Record = M::Record; - - fn collect_devices(&self, devices: burn::module::Devices) -> burn::module::Devices { - self.module.collect_devices(devices) - } - - fn fork(self, device: &B::Device) -> Self { - let module = self.module.fork(device); - Self { - module, - _manifold: PhantomData, - } - } - - fn to_device(self, device: &B::Device) -> Self { - let module = self.module.to_device(device); - Self { - module, - _manifold: PhantomData, - } - } - - fn visit>(&self, visitor: &mut Visitor) { - self.module.visit(visitor); - } - - fn map>(self, mapper: &mut Mapper) -> Self { - let module = self.module.map(mapper); - Self { - module, - _manifold: PhantomData, - } - } - - fn load_record(self, record: Self::Record) -> Self { - let module = self.module.load_record(record); - Self { - module, - _manifold: PhantomData, - } - } - - fn into_record(self) -> Self::Record { - self.module.into_record() - } -} - -impl AutodiffModule for Constrained -where - M: AutodiffModule, - B: AutodiffBackend, - Man: Clone + Debug + Send, -{ - type InnerModule = M::InnerModule; - - fn valid(&self) -> Self::InnerModule { - self.module.valid() - } -} - -impl burn::module::ModuleDisplayDefault for Constrained -where - M: burn::module::ModuleDisplayDefault, - Man: Clone + Debug + Send, -{ - fn content(&self, content: burn::module::Content) -> Option { - self.module.content(content) - } -} - -impl ModuleDisplay for Constrained -where - M: ModuleDisplay, - Man: Clone + Debug + Send, -{ - fn format(&self, passed_settings: burn::module::DisplaySettings) -> String { - format!("Constrained<{}>", self.module.format(passed_settings)) - } -} - -impl Constrained { - pub fn new(module: M) -> Self { - Self { - module, - _manifold: PhantomData, - } - } - - /// Get a reference to the inner module - pub fn inner(&self) -> &M { - &self.module - } - - /// Get a mutable reference to the inner module - pub fn inner_mut(&mut self) -> &mut M { - &mut self.module - } - - /// Apply manifold projection to a tensor - requires explicit Backend type - pub fn project_tensor( - &self, - point: Tensor, - vector: Tensor, - ) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::project(point, vector) - } - - /// Apply manifold retraction to a tensor - requires explicit Backend type - pub fn retract_tensor( - &self, - point: Tensor, - direction: Tensor, - ) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::retract(point, direction) - } - - /// Convert Euclidean gradient to Riemannian gradient - requires explicit Backend type - pub fn euclidean_to_riemannian( - &self, - point: Tensor, - grad: Tensor, - ) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::egrad2rgrad(point, grad) - } - - /// Project point onto manifold - requires explicit Backend type - pub fn project_to_manifold(&self, point: Tensor) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::proj(point) - } - - /// Get the manifold name - pub fn manifold_name(&self) -> &'static str - where - B: Backend, - Man: Manifold, - { 
- Man::name() - } -} - -/// Trait for modules that have manifold constraints -pub trait ConstrainedModule { - /// Apply manifold constraints to all parameters in the module - #[must_use] - fn apply_manifold_constraints(self) -> Self; - - /// Get information about the manifold constraints - fn get_manifold_info(&self) -> std::collections::HashMap; - - /// Check if this module has manifold constraints - fn has_manifold_constraints(&self) -> bool { - true - } -} - -/// Blanket implementation for Constrained wrapper -impl ConstrainedModule for Constrained -where - M: Module, - B: Backend, - Man: Manifold + Clone + Debug + Send, -{ - fn apply_manifold_constraints(self) -> Self { - self - } - - fn get_manifold_info(&self) -> std::collections::HashMap { - let mut info = std::collections::HashMap::new(); - info.insert("manifold_type".to_string(), Man::name().to_string()); - info - } -} diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index 4c091d2..7331f67 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -1,6 +1,5 @@ use crate::prelude::*; -/// Euclidean manifold - the simplest case where no projection is needed #[derive(Clone, Debug)] pub struct Sphere; diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index d721776..bc532ce 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -483,15 +483,43 @@ mod test { } let mut state = None; + let x_original: Tensor = + Tensor::::from_data(x.to_data(), &Default::default()); + let a_original: Tensor = + Tensor::::from_data(a.to_data(), &Default::default()); + let unoptimised_loss = x_original + .clone() + .transpose() + .matmul(a_original.clone()) + .matmul(x_original.clone()) + .sum() + .into_scalar(); + println!( + "Unoptimised tensor: {} with loss {}", + x_original, unoptimised_loss + ); (x, state) = optimiser.many_steps(|_| 0.1, 100, |x| grad_fn(x, a.clone()), x, state); assert!(state.is_none()); - println!("Optimised tensor: {}", x); + let x_optimised: Tensor = + Tensor::::from_data(x.to_data(), &Default::default()); + let optimised_loss = x_optimised + .clone() + .transpose() + .matmul(a_original) + .matmul(x_optimised.clone()) + .sum() + .into_scalar(); + println!( + "Optimised tensor: {} with loss {}", + x_optimised, optimised_loss + ); + assert!(optimised_loss <= unoptimised_loss, + "The optimimisation should have lowered the loss function. 
It was {unoptimised_loss} before and {optimised_loss} after"); } #[test] fn test_simple_optimizer_step() { let optimiser = ManifoldRGD::, TestBackend>::default(); - // Create simple test tensors let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]); diff --git a/src/optimizers.rs b/src/optimizers.rs index f1e518b..ed8e168 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -9,46 +9,12 @@ use burn::record::Record; use burn::tensor::backend::AutodiffBackend; use burn::LearningRate; use std::marker::PhantomData; +pub mod many_steps; pub mod multiple; - -pub trait LessSimpleOptimizer: SimpleOptimizer { - fn many_steps( - &self, - lr_function: impl FnMut(usize) -> LearningRate, - num_steps: usize, - grad_function: impl FnMut(Tensor) -> Tensor, - tensor: Tensor, - state: Option>, - ) -> (Tensor, Option>); -} - -impl> LessSimpleOptimizer for T { - #[inline] - fn many_steps( - &self, - mut lr_function: impl FnMut(usize) -> LearningRate, - num_steps: usize, - mut grad_function: impl FnMut(Tensor) -> Tensor, - mut tensor: Tensor, - mut state: Option>, - ) -> (Tensor, Option>) { - // Perform optimization steps - for step in 0..num_steps { - // Compute gradient at tensor - let cur_grad = grad_function(tensor.clone()); - // The current learning rate for this step - let cur_lr = lr_function(step); - // Perform optimizer step - let (new_x, new_state) = self.step(cur_lr, tensor.clone(), cur_grad, state); - tensor = new_x.detach().require_grad(); - state = new_state; - } - (tensor, state) - } -} +pub use many_steps::LessSimpleOptimizer; #[derive(Debug)] -pub struct ManifoldRGDConfig { +pub struct ManifoldRGDConfig, B: Backend> { _manifold: PhantomData, _backend: PhantomData, } @@ -110,11 +76,17 @@ where } fn to_device( - _state: Self::State, - _device: &::Device, + state: Self::State, + device: &::Device, ) -> Self::State { - #[allow(clippy::used_underscore_binding)] - _state + const DECAY_STATE_TO_DEVICE: bool = false; + if DECAY_STATE_TO_DEVICE { + ManifoldRGDState { + lr_decay: state.lr_decay.to_device(device), + } + } else { + state + } } } @@ -155,7 +127,7 @@ where /// .with_amsgrad(true); /// ``` #[derive(Debug, Clone)] -pub struct RiemannianAdamConfig { +pub struct RiemannianAdamConfig, B: Backend> { pub lr: f64, pub beta1: f64, pub beta2: f64, diff --git a/src/optimizers/many_steps.rs b/src/optimizers/many_steps.rs new file mode 100644 index 0000000..2c53c52 --- /dev/null +++ b/src/optimizers/many_steps.rs @@ -0,0 +1,50 @@ +//! A optimizer that allows for many steps with a given learning schedule +//! and a way of evaluating the gradient function on arbitrary +//! points. This way we can step using `SimpleOptimizer::step` with that gradient +//! several times. +use crate::prelude::*; +use burn::{optim::SimpleOptimizer, LearningRate}; + +/// A optimizer that allows for many steps with a given learning schedule +/// and a way of evaluating the gradient function on arbitrary +/// points. This way we can step using `SimpleOptimizer::step` with that gradient +/// several times. 
+pub trait LessSimpleOptimizer: SimpleOptimizer { + fn many_steps( + &self, + lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + grad_function: impl FnMut(Tensor) -> Tensor, + tensor: Tensor, + state: Option>, + ) -> (Tensor, Option>); +} + +/// The implementation of `LessSimpleOptimizer` is completely determined +/// by how `SimpleOptimizer` has been implemented because we are +/// just taking gradients using the input `grad_function` and steping with +/// `SimpleOptimizer::step` +impl> LessSimpleOptimizer for T { + #[inline] + fn many_steps( + &self, + mut lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + mut grad_function: impl FnMut(Tensor) -> Tensor, + mut tensor: Tensor, + mut state: Option>, + ) -> (Tensor, Option>) { + // Perform optimization steps + for step in 0..num_steps { + // Compute gradient at tensor + let cur_grad = grad_function(tensor.clone()); + // The current learning rate for this step + let cur_lr = lr_function(step); + // Perform optimizer step + let (new_x, new_state) = self.step(cur_lr, tensor.clone(), cur_grad, state); + tensor = new_x.detach().require_grad(); + state = new_state; + } + (tensor, state) + } +} diff --git a/src/optimizers/multiple.rs b/src/optimizers/multiple.rs index 78cb9fd..7932d31 100644 --- a/src/optimizers/multiple.rs +++ b/src/optimizers/multiple.rs @@ -11,7 +11,7 @@ use std::marker::PhantomData; use burn::module::Module; -use crate::manifolds::Constrained; +use crate::constrained_module::Constrained; use crate::prelude::*; /// Multi-manifold optimizer configuration From 4c3e18aa7d079e113309ed82ad5a9069e1d054c0 Mon Sep 17 00:00:00 2001 From: Cobord Date: Tue, 2 Dec 2025 17:26:09 -0500 Subject: [PATCH 06/13] debugging, methods on Manifold are supposed to take D rank tensors but fail with runtime errors if they are not precisely 1 or 2 depending on the example, an optimizer that is allowed to use hessian but just ignores it --- examples/multi_constraints.rs | 7 +- src/manifolds.rs | 39 +++++-- src/manifolds/sphere.rs | 163 ++++++++++++++++++++++++++-- src/manifolds/steifiel.rs | 18 +++ src/optimizers.rs | 1 + src/optimizers/hessian_optimizer.rs | 137 +++++++++++++++++++++++ 6 files changed, 344 insertions(+), 21 deletions(-) create mode 100644 src/optimizers/hessian_optimizer.rs diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index 5dda54b..cab6699 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -14,6 +14,9 @@ use manopt_rs::prelude::*; pub struct CustomSphereManifold; impl Manifold for CustomSphereManifold { + type PointOnManifold = Tensor; + type TangentVectorWithoutPoint = Tensor; + fn new() -> Self { Self } @@ -51,13 +54,13 @@ impl Manifold for CustomSphereManifold { fn is_in_manifold(point: Tensor) -> bool { let r_squared = point.clone().powf_scalar(2.0).sum(); - let one = Tensor::::from_floats([1.0], &<::Device>::default()); + let one = Tensor::::from_floats([1.0], &r_squared.device()); r_squared.all_close(one, None, None) } fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { let dot_product = (point * vector).sum(); - let zero = Tensor::::from_floats([0.0], &<::Device>::default()); + let zero = Tensor::::from_floats([0.0], &dot_product.device()); dot_product.all_close(zero, None, Some(1e-6)) } } diff --git a/src/manifolds.rs b/src/manifolds.rs index 1cd4320..cf800bf 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -32,6 +32,9 @@ pub use sphere::Sphere; /// struct MyManifold; /// /// impl Manifold for MyManifold { +/// type 
PointOnManifold = Tensor; +/// +/// type TangentVectorWithoutPoint = Tensor; /// fn new() -> Self { MyManifold } /// fn name() -> &'static str { "MyManifold" } /// @@ -52,27 +55,36 @@ pub use sphere::Sphere; /// } /// ``` pub trait Manifold: Clone + Send + Sync { + type PointOnManifold; + type TangentVectorWithoutPoint; + fn new() -> Self; fn name() -> &'static str; + /// Project `vector` to the tangent space at `point` fn project(point: Tensor, vector: Tensor) -> Tensor; + + // Move along the manifold from `point` along the tangent vector `direction` with step size fn retract(point: Tensor, direction: Tensor) -> Tensor; - /// Convert Euclidean gradient to Riemannian gradient + /// Convert Euclidean gradient `grad` to Riemannian gradient at `point` fn egrad2rgrad(point: Tensor, grad: Tensor) -> Tensor { Self::project(point, grad) } - /// Riemannian inner product at a given point + /// Riemannian inner product at a given `point` + /// `u` and `v` are in the tangent space at `point` fn inner(point: Tensor, u: Tensor, v: Tensor) -> Tensor; - /// Exponential map: move from point along tangent vector u with step size + /// Exponential map: move from `point` along tangent vector `direction` with step size fn expmap(point: Tensor, direction: Tensor) -> Tensor { Self::retract(point, direction) } - /// Parallel transport of tangent vector from point1 to point2 + /// Parallel transport of a tangent vector `tangent` from `point1` to `point2` + /// By default, this is not accurately implemented and ignores the metric/connection + /// just projecting to the tangent space. fn parallel_transport( _point1: Tensor, point2: Tensor, @@ -82,25 +94,25 @@ pub trait Manifold: Clone + Send + Sync { Self::project_tangent(point2, tangent) } - /// Project vector to tangent space at point + /// Project `vector` to the tangent space at `point` fn project_tangent(point: Tensor, vector: Tensor) -> Tensor { Self::project(point, vector) } - /// Project point onto manifold + /// Project `point` onto manifold fn proj(point: Tensor) -> Tensor { point } - /// Check if a point is in the manifold. - /// By default, this is not implemented and returns `false`. + /// Check if a `point` is in the manifold. + /// By default, this is not accurately implemented and returns `false`. fn is_in_manifold(_point: Tensor) -> bool { false } - /// Check if a vector is in the tangent space at point - /// given that point is in the manifold. - /// By default, this is not implemented and returns `false`. + /// Check if a `vector` is in the tangent space at `point` + /// given that `point` is in the manifold. + /// By default, this is not accurately implemented and returns `false`. 
fn is_tangent_at(_point: Tensor, _vector: Tensor) -> bool { false } @@ -111,6 +123,9 @@ pub trait Manifold: Clone + Send + Sync { pub struct Euclidean; impl Manifold for Euclidean { + type PointOnManifold = Tensor; + + type TangentVectorWithoutPoint = Tensor; fn new() -> Self { Self } @@ -132,7 +147,7 @@ impl Manifold for Euclidean { u: Tensor, v: Tensor, ) -> Tensor { - u * v + (u * v).sum_dim(D - 1) } fn is_in_manifold(_point: Tensor) -> bool { diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index 7331f67..92933b7 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -4,6 +4,10 @@ use crate::prelude::*; pub struct Sphere; impl Manifold for Sphere { + type PointOnManifold = Tensor; + + type TangentVectorWithoutPoint = Tensor; + fn new() -> Self { Self } @@ -12,13 +16,17 @@ impl Manifold for Sphere { "Sphere" } - fn project(_point: Tensor, vector: Tensor) -> Tensor { - // Y/||y| - vector.clone() / (vector.clone().transpose().matmul(vector)).sqrt() + fn project(point: Tensor, vector: Tensor) -> Tensor { + // For sphere: project vector orthogonal to point + let dot_product = (point.clone() * vector.clone()).sum_dim(D - 1); + vector - point * dot_product } - fn retract(_point: Tensor, _direction: Tensor) -> Tensor { - todo!("Implement retract for Sphere manifold") + fn retract(point: Tensor, direction: Tensor) -> Tensor { + // For sphere: normalize the result + let new_point = point + direction; + let norm = new_point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + new_point / norm } fn inner( @@ -29,7 +37,148 @@ impl Manifold for Sphere { u.transpose().matmul(v) } - fn is_in_manifold(_point: Tensor) -> bool { - true + fn is_in_manifold(point: Tensor) -> bool { + let r_squared = point.clone().powf_scalar(2.0).sum_dim(D - 1); + let one = Tensor::::ones(r_squared.shape(), &point.device()); + r_squared.all_close(one, None, None) + } + + fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { + let dot_product = (point * vector).sum(); + let zero = Tensor::::from_floats([0.0], &dot_product.device()); + dot_product.all_close(zero, None, Some(1e-6)) + } +} + +#[cfg(test)] +mod test { + use crate::prelude::Manifold; + + use super::Sphere; + use burn::{ + backend::{Autodiff, NdArray}, + tensor::Tensor, + }; + + type TestBackend = Autodiff; + type TestTensor = Tensor; + + const TOLERANCE: f32 = 1e-6; + + fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) { + let diff = (a.clone() - b.clone()).abs(); + let max_diff = diff.max().into_scalar(); + assert!( + max_diff < tol, + "Tensors differ by {}, tolerance: {}", + max_diff, + tol + ); + } + + fn create_test_matrix(rows: usize, values: Vec) -> TestTensor { + let device = Default::default(); + let data = &values[0..rows]; + Tensor::from_floats(data, &device) + } + + #[test] + fn test_manifold_creation() { + let _manifold = >::new(); + assert_eq!(>::name(), "Sphere"); + } + + #[test] + fn test_projection_tangent_space() { + // Create a point on the Sphere manifold + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + // Create a direction vector + let direction = create_test_matrix(6, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]); + + let projected = + >::project(point.clone(), direction.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (point.clone() * projected.clone()).sum(); + let max_entry = product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent 
space: absoulte value of the dot product = {}", + max_entry + ); + } + + #[test] + fn test_projection_preserves_tangent_vectors() { + // Create a point on the Sphere manifold + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()), + "This is a point on the sphere by construction" + ); + + // Create a direction vector + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + assert!( + Sphere::is_tangent_at(point.clone(), direction.clone()), + "This direction is orthogonal to point by construction" + ); + + let projected = + >::project(point.clone(), direction.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (point.clone() * projected.clone()).sum(); + let max_entry = product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent space: absoulte value of the dot product = {}", + max_entry + ); + + assert!( + Sphere::is_tangent_at(point.clone(), projected.clone()), + "Projecting something already in the tangent space stays in the tangent space" + ); + assert_tensor_close(&projected, &direction, TOLERANCE); + } + + #[test] + fn test_retraction_preserves_sphere_property() { + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()), + "This is a point on the sphere by construction" + ); + + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + let moved = Sphere::retract(point, direction); + + assert!(Sphere::is_in_manifold(moved)); + } + + #[test] + fn test_parallel_transport() { + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()), + "This is a point on the sphere by construction" + ); + + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + let moved_point = Sphere::retract(point.clone(), direction.clone()); + let moved_vector = Sphere::parallel_transport(point, moved_point.clone(), direction); + + assert!(Sphere::is_in_manifold(moved_point.clone())); + assert!(Sphere::is_tangent_at(moved_point, moved_vector)); } } diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index bc532ce..b56cb27 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -1,3 +1,5 @@ +use burn::tensor::cast::ToElement; + use crate::prelude::*; #[derive(Debug, Clone, Default)] @@ -6,6 +8,10 @@ pub struct SteifielsManifold { } impl Manifold for SteifielsManifold { + type PointOnManifold = Tensor; + + type TangentVectorWithoutPoint = Tensor; + fn new() -> Self { SteifielsManifold { _backend: std::marker::PhantomData, @@ -38,6 +44,14 @@ impl Manifold for SteifielsManifold { // For Stiefel manifold, we use the standard Euclidean inner product u * v } + + fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { + let xtv = point.clone().transpose().matmul(vector.clone()); + let vtx = vector.clone().transpose().matmul(point.clone()); + let skew = xtv + vtx.transpose(); + let max_skew = skew.abs().max().into_scalar(); + max_skew.to_f64() < 1e-6 + } } fn gram_schmidt(v: &Tensor) -> Tensor { @@ -296,6 +310,10 @@ mod test { "Tangent space property violated: max skew = {}", max_skew ); + assert!( + SteifielsManifold::is_tangent_at(point, tangent), + "Tangent space property violated: max skew unknown" + ) } #[test] diff 
--git a/src/optimizers.rs b/src/optimizers.rs index ed8e168..4887b3b 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -9,6 +9,7 @@ use burn::record::Record; use burn::tensor::backend::AutodiffBackend; use burn::LearningRate; use std::marker::PhantomData; +pub mod hessian_optimizer; pub mod many_steps; pub mod multiple; pub use many_steps::LessSimpleOptimizer; diff --git a/src/optimizers/hessian_optimizer.rs b/src/optimizers/hessian_optimizer.rs new file mode 100644 index 0000000..6c52ad8 --- /dev/null +++ b/src/optimizers/hessian_optimizer.rs @@ -0,0 +1,137 @@ +use burn::{ + optim::SimpleOptimizer, prelude::Backend, record::Record, tensor::Tensor, LearningRate, +}; + +/// TODO document and construct an implementation +pub trait SimpleHessianOptimizer: Send + Sync + Clone + SimpleOptimizer +where + B: Backend, +{ + /// The state of the optimizer. It also implements [record](Record), so that it can be saved. + type StateWithHessian: Record + + Clone + + From> + + Into> + + 'static; + + /// The optimizer step is performed for one tensor at a time with its gradient, hessian and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. + fn step_with_hessian( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + hessian: Tensor, + state: Option>, + ) -> (Tensor, Option>); + + /// The optimizer step is performed for one tensor at a time with its gradient, hessian and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. + fn step( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + hessian: Option>, + state: Option>, + ) -> (Tensor, Option>) { + if let Some(hessian) = hessian { + self.step_with_hessian(lr, tensor, grad, hessian, state) + } else { + let (new_pt, new_state) = + SimpleOptimizer::step(self, lr, tensor, grad, state.map(Into::into)); + (new_pt, new_state.map(Into::into)) + } + } + + /// Change the device of the state. + /// + /// This function will be called accordindly to have the state on the same device as the + /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. + fn to_device( + state: Self::StateWithHessian, + device: &B::Device, + ) -> Self::StateWithHessian; +} + +/// A `SimpleOptimizer` also works as a `SimpleHessianOptimizer` by ignoring the Hessian information +impl> SimpleHessianOptimizer for T { + type StateWithHessian = >::State; + + fn step_with_hessian( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + _hessian: Tensor, + state: Option>, + ) -> (Tensor, Option>) { + self.step(lr, tensor, grad, state) + } + + fn to_device( + state: Self::StateWithHessian, + device: &::Device, + ) -> Self::StateWithHessian { + >::to_device(state, device) + } +} + +/// A optimizer that allows for many steps with a given learning schedule +/// and a way of evaluating the gradient and hessian functions on arbitrary +/// points. This way we can step using `SimpleHessianOptimizer::step` with +/// that gradient and hessian several times. 
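// [Editorial sketch, not part of the patch: a minimal illustration of how the
// trait declared just below is meant to be driven, assuming the generic
// parameters implied by the signatures above, i.e. `LessSimpleHessianOptimizer<B>`
// over a `Backend` `B` with the tensor rank `D` chosen per call. It minimizes
// f(x) = ||x||^2 with a constant learning rate and the exact gradient 2x, and
// passes no Hessian, so each step takes the plain `SimpleOptimizer::step`
// fallback of the blanket impl above.]
fn sketch_minimize_norm_squared<B: Backend, O: LessSimpleHessianOptimizer<B>>(
    opt: &O,
    x0: Tensor<B, 2>,
) -> Tensor<B, 2> {
    let (x_min, _state) = opt.many_steps(
        |_step| 1e-2,          // learning-rate schedule: constant
        200,                   // number of optimization steps
        |x| x.mul_scalar(2.0), // gradient of ||x||^2 at x
        |_x| None,             // no Hessian information available
        x0,                    // starting point
        None,                  // no prior optimizer state
    );
    x_min
}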
+pub trait LessSimpleHessianOptimizer: SimpleHessianOptimizer { + fn many_steps( + &self, + lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + grad_function: impl FnMut(Tensor) -> Tensor, + hessian_function: impl FnMut(Tensor) -> Option>, + tensor: Tensor, + state: Option>, + ) -> (Tensor, Option>); +} + +/// The implementation of `LessSimpleHessianOptimizer` is completely determined +/// by how `SimpleHessianOptimizer` has been implemented because we are +/// just taking gradients using the input `grad_function` and hessians with `hessian_function` +/// and steping with `SimpleHessianOptimizer::step` +impl> LessSimpleHessianOptimizer for T { + #[inline] + fn many_steps( + &self, + mut lr_function: impl FnMut(usize) -> LearningRate, + num_steps: usize, + mut grad_function: impl FnMut(Tensor) -> Tensor, + mut hessian_function: impl FnMut(Tensor) -> Option>, + mut tensor: Tensor, + mut state: Option>, + ) -> (Tensor, Option>) { + // Perform optimization steps + for step in 0..num_steps { + // Compute gradient at tensor + let cur_grad = grad_function(tensor.clone()); + // Compute hessian at tensor + let cur_hessian = hessian_function(tensor.clone()); + // The current learning rate for this step + let cur_lr = lr_function(step); + // Perform optimizer step + let (new_x, new_state) = SimpleHessianOptimizer::step( + self, + cur_lr, + tensor.clone(), + cur_grad, + cur_hessian, + state, + ); + tensor = new_x.detach().require_grad(); + state = new_state; + } + (tensor, state) + } +} From ef13d680535d12ce6c39fb885c595a0d95b78039 Mon Sep 17 00:00:00 2001 From: Cobord Date: Tue, 9 Dec 2025 21:05:22 -0500 Subject: [PATCH 07/13] progress on fixing channels as the first all but some const parameter dimensions if fix all those coordinates the remaining tensor is a single point on the manifold --- examples/multi_constraints.rs | 45 ++++++++----- src/manifolds.rs | 115 +++++++++++++++++++--------------- src/manifolds/sphere.rs | 52 ++++++++------- src/manifolds/steifiel.rs | 46 ++++++++++---- src/optimizers.rs | 2 +- 5 files changed, 159 insertions(+), 101 deletions(-) diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index cab6699..075ed79 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -14,8 +14,7 @@ use manopt_rs::prelude::*; pub struct CustomSphereManifold; impl Manifold for CustomSphereManifold { - type PointOnManifold = Tensor; - type TangentVectorWithoutPoint = Tensor; + const RANK_PER_POINT: usize = 1; fn new() -> Self { Self @@ -27,14 +26,21 @@ impl Manifold for CustomSphereManifold { fn project(point: Tensor, vector: Tensor) -> Tensor { // For sphere: project vector orthogonal to point - let dot_product = (point.clone() * vector.clone()).sum(); - vector - point * dot_product.unsqueeze() + debug_assert!(point.shape() == vector.shape()); + let dot_product = + (point.clone() * vector.clone()).sum_dim(D - >::RANK_PER_POINT); + vector - point * dot_product } fn retract(point: Tensor, direction: Tensor) -> Tensor { // For sphere: normalize the result + debug_assert!(point.shape() == direction.shape()); let new_point = point + direction; - let norm = new_point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + let norm = new_point + .clone() + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT) + .sqrt(); new_point / norm } @@ -43,25 +49,34 @@ impl Manifold for CustomSphereManifold { u: Tensor, v: Tensor, ) -> Tensor { - u * v + (u * v).sum_dim(D - >::RANK_PER_POINT) } fn proj(point: Tensor) -> Tensor { // Project point onto 
unit sphere - let norm = point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + let norm = point + .clone() + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT) + .sqrt(); point / norm } - fn is_in_manifold(point: Tensor) -> bool { - let r_squared = point.clone().powf_scalar(2.0).sum(); - let one = Tensor::::from_floats([1.0], &r_squared.device()); - r_squared.all_close(one, None, None) + fn is_in_manifold(point: Tensor) -> Tensor { + let r_squared = point + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT); + let one = r_squared.ones_like(); + r_squared.is_close(one, None, None) } - fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { - let dot_product = (point * vector).sum(); - let zero = Tensor::::from_floats([0.0], &dot_product.device()); - dot_product.all_close(zero, None, Some(1e-6)) + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let dot_product = (point * vector).sum_dim(D - >::RANK_PER_POINT); + let zeros = dot_product.zeros_like(); + dot_product.is_close(zeros, None, Some(1e-6)) } } diff --git a/src/manifolds.rs b/src/manifolds.rs index cf800bf..1f8dd98 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -25,61 +25,45 @@ pub use sphere::Sphere; /// /// # Example Implementation /// -/// ```rust +/// ```r u st /// use manopt_rs::prelude::*; /// /// #[derive(Clone)] /// struct MyManifold; /// /// impl Manifold for MyManifold { -/// type PointOnManifold = Tensor; -/// -/// type TangentVectorWithoutPoint = Tensor; -/// fn new() -> Self { MyManifold } -/// fn name() -> &'static str { "MyManifold" } -/// -/// fn project(point: Tensor, vector: Tensor) -> Tensor { -/// // Project vector to tangent space at point -/// vector -/// } -/// -/// fn retract(point: Tensor, direction: Tensor) -> Tensor { -/// // Move along manifold from point in direction with step size -/// point + direction -/// } -/// -/// fn inner(_point: Tensor, u: Tensor, v: Tensor) -> Tensor { -/// // Riemannian inner product at point -/// u * v -/// } /// } /// ``` pub trait Manifold: Clone + Send + Sync { - type PointOnManifold; - type TangentVectorWithoutPoint; + const RANK_PER_POINT: usize; fn new() -> Self; fn name() -> &'static str; + fn specific_name(s: &Shape) -> String { + let dims = &s.dims; + let num_dims = dims.len(); + let (channel_dims, manifold_dims) = dims.split_at(num_dims - Self::RANK_PER_POINT); + format!( + "{channel_dims:?} Channels worth of points in {} with specific n's {manifold_dims:?}", + Self::name() + ) + } + + fn acceptable_shape(s: &Shape) -> bool { + s.num_dims() >= Self::RANK_PER_POINT + } /// Project `vector` to the tangent space at `point` fn project(point: Tensor, vector: Tensor) -> Tensor; - // Move along the manifold from `point` along the tangent vector `direction` with step size - fn retract(point: Tensor, direction: Tensor) -> Tensor; - /// Convert Euclidean gradient `grad` to Riemannian gradient at `point` fn egrad2rgrad(point: Tensor, grad: Tensor) -> Tensor { Self::project(point, grad) } - /// Riemannian inner product at a given `point` - /// `u` and `v` are in the tangent space at `point` - fn inner(point: Tensor, u: Tensor, v: Tensor) - -> Tensor; - - /// Exponential map: move from `point` along tangent vector `direction` with step size - fn expmap(point: Tensor, direction: Tensor) -> Tensor { - Self::retract(point, direction) + /// Project `vector` to the tangent space at `point` + fn project_tangent(point: Tensor, vector: Tensor) -> Tensor { + Self::project(point, vector) } /// Parallel transport of a tangent vector `tangent` from 
`point1` to `point2` @@ -91,31 +75,35 @@ pub trait Manifold: Clone + Send + Sync { tangent: Tensor, ) -> Tensor { // Default implementation: project to tangent space at point2 - Self::project_tangent(point2, tangent) + Self::project_tangent(point2, tangent.into()) } - /// Project `vector` to the tangent space at `point` - fn project_tangent(point: Tensor, vector: Tensor) -> Tensor { - Self::project(point, vector) + /// Move along the manifold from `point` along the tangent vector `direction` with step size + fn retract(point: Tensor, direction: Tensor) -> Tensor; + + /// Exponential map: move from `point` along tangent vector `direction` with step size + fn expmap(point: Tensor, direction: Tensor) -> Tensor { + Self::retract(point, direction) } + /// Riemannian inner product at a given `point` + /// `u` and `v` are in the tangent space at `point` + fn inner(point: Tensor, u: Tensor, v: Tensor) + -> Tensor; + /// Project `point` onto manifold - fn proj(point: Tensor) -> Tensor { - point - } + fn proj(point: Tensor) -> Tensor; /// Check if a `point` is in the manifold. - /// By default, this is not accurately implemented and returns `false`. - fn is_in_manifold(_point: Tensor) -> bool { - false - } + fn is_in_manifold(point: Tensor) -> Tensor; /// Check if a `vector` is in the tangent space at `point` /// given that `point` is in the manifold. /// By default, this is not accurately implemented and returns `false`. - fn is_tangent_at(_point: Tensor, _vector: Tensor) -> bool { - false - } + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor; } /// Euclidean manifold - the simplest case where no projection is needed @@ -123,9 +111,8 @@ pub trait Manifold: Clone + Send + Sync { pub struct Euclidean; impl Manifold for Euclidean { - type PointOnManifold = Tensor; + const RANK_PER_POINT: usize = 1; - type TangentVectorWithoutPoint = Tensor; fn new() -> Self { Self } @@ -150,7 +137,31 @@ impl Manifold for Euclidean { (u * v).sum_dim(D - 1) } - fn is_in_manifold(_point: Tensor) -> bool { - true + fn is_in_manifold( + point: Tensor, + ) -> burn::tensor::Tensor { + point + .clone() + .detach() + .is_nan() + .any_dim(>::RANK_PER_POINT) + .bool_not() + } + + fn proj(point: Tensor) -> Tensor { + point + } + + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let vector_exists = vector + .clone() + .detach() + .is_nan() + .any_dim(>::RANK_PER_POINT) + .bool_not(); + Self::is_in_manifold(point).bool_and(vector_exists) } } diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index 92933b7..dad0aaf 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -4,9 +4,7 @@ use crate::prelude::*; pub struct Sphere; impl Manifold for Sphere { - type PointOnManifold = Tensor; - - type TangentVectorWithoutPoint = Tensor; + const RANK_PER_POINT: usize = 1; fn new() -> Self { Self @@ -25,7 +23,7 @@ impl Manifold for Sphere { fn retract(point: Tensor, direction: Tensor) -> Tensor { // For sphere: normalize the result let new_point = point + direction; - let norm = new_point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + let norm = new_point.clone().powf_scalar(2.0).sum_dim(D - 1).sqrt(); new_point / norm } @@ -34,19 +32,27 @@ impl Manifold for Sphere { u: Tensor, v: Tensor, ) -> Tensor { - u.transpose().matmul(v) + u * v.sum_dim(D - 1) + } + + fn is_in_manifold(point: Tensor) -> Tensor { + let r_squared = point.powf_scalar(2.0).sum_dim(D - 1); + let one = r_squared.ones_like(); + r_squared.is_close(one, None, None) } - fn is_in_manifold(point: Tensor) -> 
bool { - let r_squared = point.clone().powf_scalar(2.0).sum_dim(D - 1); - let one = Tensor::::ones(r_squared.shape(), &point.device()); - r_squared.all_close(one, None, None) + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let dot_product = (point * vector).sum_dim(D - 1); + let zero = dot_product.zeros_like(); + dot_product.is_close(zero, None, Some(1e-6)) } - fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { - let dot_product = (point * vector).sum(); - let zero = Tensor::::from_floats([0.0], &dot_product.device()); - dot_product.all_close(zero, None, Some(1e-6)) + fn proj(point: Tensor) -> Tensor { + let norm = point.clone().powf_scalar(2.0).sum_dim(D - 1).sqrt(); + point / norm } } @@ -86,6 +92,10 @@ mod test { fn test_manifold_creation() { let _manifold = >::new(); assert_eq!(>::name(), "Sphere"); + assert_eq!(>::specific_name(&burn::tensor::Shape{dims: vec![5]}), + "[] Channels worth of points in Sphere with specific n's [5]"); + assert_eq!(>::specific_name(&burn::tensor::Shape{dims: vec![10,30,5]}), + "[10, 30] Channels worth of points in Sphere with specific n's [5]"); } #[test] @@ -116,7 +126,7 @@ mod test { let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); assert!( - Sphere::is_in_manifold(point.clone()), + Sphere::is_in_manifold(point.clone()).into_scalar(), "This is a point on the sphere by construction" ); @@ -124,7 +134,7 @@ mod test { let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); assert!( - Sphere::is_tangent_at(point.clone(), direction.clone()), + Sphere::is_tangent_at(point.clone(), direction.clone()).into_scalar(), "This direction is orthogonal to point by construction" ); @@ -142,7 +152,7 @@ mod test { ); assert!( - Sphere::is_tangent_at(point.clone(), projected.clone()), + Sphere::is_tangent_at(point.clone(), projected.clone()).into_scalar(), "Projecting something already in the tangent space stays in the tangent space" ); assert_tensor_close(&projected, &direction, TOLERANCE); @@ -153,7 +163,7 @@ mod test { let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); assert!( - Sphere::is_in_manifold(point.clone()), + Sphere::is_in_manifold(point.clone()).into_scalar(), "This is a point on the sphere by construction" ); @@ -161,7 +171,7 @@ mod test { let moved = Sphere::retract(point, direction); - assert!(Sphere::is_in_manifold(moved)); + assert!(Sphere::is_in_manifold(moved).into_scalar()); } #[test] @@ -169,7 +179,7 @@ mod test { let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); assert!( - Sphere::is_in_manifold(point.clone()), + Sphere::is_in_manifold(point.clone()).into_scalar(), "This is a point on the sphere by construction" ); @@ -178,7 +188,7 @@ mod test { let moved_point = Sphere::retract(point.clone(), direction.clone()); let moved_vector = Sphere::parallel_transport(point, moved_point.clone(), direction); - assert!(Sphere::is_in_manifold(moved_point.clone())); - assert!(Sphere::is_tangent_at(moved_point, moved_vector)); + assert!(Sphere::is_in_manifold(moved_point.clone()).into_scalar()); + assert!(Sphere::is_tangent_at(moved_point, moved_vector).into_scalar()); } } diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index b56cb27..55d4f83 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -1,5 +1,3 @@ -use burn::tensor::cast::ToElement; - use crate::prelude::*; #[derive(Debug, Clone, Default)] @@ -8,9 +6,7 @@ pub struct SteifielsManifold { } impl Manifold for 
SteifielsManifold { - type PointOnManifold = Tensor; - - type TangentVectorWithoutPoint = Tensor; + const RANK_PER_POINT: usize = 2; fn new() -> Self { SteifielsManifold { @@ -32,8 +28,23 @@ impl Manifold for SteifielsManifold { } fn retract(point: Tensor, direction: Tensor) -> Tensor { - let s = point + direction; - gram_schmidt(&s) + debug_assert!(point.dims().len() >= Self::RANK_PER_POINT); + debug_assert!(direction.dims().len() >= Self::RANK_PER_POINT); + let mut s = point + direction; + if s.dims().len() > Self::RANK_PER_POINT { + // Gram_schmidt as written does so on the first two coordinates + // unlike Matrix multiplication and the rest of tensor operations + // which is assuming the last two coordinates + // and the first bunch being channels instead of vice versa + s = s.swap_dims(0, D - 2); + s = s.swap_dims(1, D - 1); + s = gram_schmidt(&s); + s = s.swap_dims(1, D - 1); + s = s.swap_dims(0, D - 2); + s + } else { + gram_schmidt(&s) + } } fn inner( @@ -42,15 +53,26 @@ impl Manifold for SteifielsManifold { v: Tensor, ) -> Tensor { // For Stiefel manifold, we use the standard Euclidean inner product - u * v + (u * v).sum_dim(D - 1).sum_dim(D - 2) } - fn is_tangent_at(point: Tensor, vector: Tensor) -> bool { + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { let xtv = point.clone().transpose().matmul(vector.clone()); let vtx = vector.clone().transpose().matmul(point.clone()); let skew = xtv + vtx.transpose(); - let max_skew = skew.abs().max().into_scalar(); - max_skew.to_f64() < 1e-6 + let max_skew = skew.clone().abs().max_dim(D - 1).max_dim(D - 2); + max_skew.lower_elem(1e-6) + } + + fn proj(_point: Tensor) -> Tensor { + todo!() + } + + fn is_in_manifold(_point: Tensor) -> Tensor { + todo!() } } @@ -311,7 +333,7 @@ mod test { max_skew ); assert!( - SteifielsManifold::is_tangent_at(point, tangent), + SteifielsManifold::is_tangent_at(point, tangent).into_scalar(), "Tangent space property violated: max skew unknown" ) } diff --git a/src/optimizers.rs b/src/optimizers.rs index 4887b3b..5634904 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -288,7 +288,7 @@ where state.exp_avg = state.exp_avg.clone() * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1); - let inner_product = M::inner(tensor.clone(), rgrad.clone(), rgrad.clone()); + let inner_product = M::inner::(tensor.clone(), rgrad.clone(), rgrad.clone()); state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2 + inner_product * (1.0 - self.config.beta2); From 23b41bbacb4170ab9680318993f59e3bd1cb7e70 Mon Sep 17 00:00:00 2001 From: Cobord Date: Wed, 10 Dec 2025 15:09:51 -0500 Subject: [PATCH 08/13] clarify how this is a family of manifolds and the last RANK_PER_POINT dimensions might have some constraints --- examples/multi_constraints.rs | 5 ++ src/manifolds.rs | 47 +++++++++++++----- src/manifolds/sphere.rs | 93 +++++++++++++++++++++++++++++++++-- src/manifolds/steifiel.rs | 6 +++ 4 files changed, 133 insertions(+), 18 deletions(-) diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index 075ed79..8c3dab2 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -78,6 +78,11 @@ impl Manifold for CustomSphereManifold { let zeros = dot_product.zeros_like(); dot_product.is_close(zeros, None, Some(1e-6)) } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = *a_is.first().expect("The ambient R^n does exist"); + n > 0 + } } #[derive(Debug, Clone)] diff --git a/src/manifolds.rs b/src/manifolds.rs index 1f8dd98..638e215 100644 --- 
a/src/manifolds.rs +++ b/src/manifolds.rs @@ -15,6 +15,8 @@ pub mod sphere; pub use sphere::Sphere; /// A Riemannian manifold defines the geometric structure for optimization. +/// This is actually for a family of manifolds parameterized by some natural numbers. +/// /// /// This trait provides all the necessary operations for Riemannian optimization: /// - Tangent space projections @@ -23,22 +25,12 @@ pub use sphere::Sphere; /// - Parallel transport /// - Riemannian inner products /// -/// # Example Implementation -/// -/// ```r u st -/// use manopt_rs::prelude::*; -/// -/// #[derive(Clone)] -/// struct MyManifold; -/// -/// impl Manifold for MyManifold { -/// } -/// ``` pub trait Manifold: Clone + Send + Sync { const RANK_PER_POINT: usize; fn new() -> Self; fn name() -> &'static str; + #[must_use] fn specific_name(s: &Shape) -> String { let dims = &s.dims; let num_dims = dims.len(); @@ -49,10 +41,35 @@ pub trait Manifold: Clone + Send + Sync { ) } + /// The manifold lives in `R^a_1 \times R^{a_{RANK_PER_POINT}}` + /// so if we have a Tensor of shape `s` + /// then it's last `RANK_PER_POINT` dimensions will be those a's + /// with the previous dimensions being used as channels. + /// Those a's then must be allowed + /// For example in a Matrix Lie group they will be + /// `R^n \times R^n`, giving a constraint that those last + /// two dimensions in the shape be equal to each other. + #[must_use] fn acceptable_shape(s: &Shape) -> bool { - s.num_dims() >= Self::RANK_PER_POINT + let enough_points = s.num_dims() >= Self::RANK_PER_POINT; + if !enough_points { + return false; + } + let (_, a_i) = s.dims.split_at(s.num_dims() - Self::RANK_PER_POINT); + Self::acceptable_dims(a_i) } + /// The manifold lives in `R^a_1 \times R^{a_{RANK_PER_POINT}}` + /// Those a's must be allowed . + /// For example in a Matrix Lie group they will be + /// `R^n \times R^n`, giving a constraint that those last + /// two dimensions in the shape be equal to each other. + /// For the purposes of this, we are allowed to assume the slice + /// is of length `Self::RANK_PER_POINT` + /// because it should only be called through `Self::acceptable_shape`. + /// Putting this in the type would not be allowed without unstable features. 
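// [Editorial sketch, not part of the patch: a free-standing illustration of the
// channel/point split described above, for a hypothetical square-matrix family
// with RANK_PER_POINT = 2. Only the trailing RANK_PER_POINT entries of a shape
// are subject to the `acceptable_dims`-style check; everything before them is
// treated as channels.]
fn sketch_square_matrix_shape_ok(dims: &[usize]) -> bool {
    const RANK_PER_POINT: usize = 2;
    if dims.len() < RANK_PER_POINT {
        return false;
    }
    let (_channels, a_is) = dims.split_at(dims.len() - RANK_PER_POINT);
    // the analogue of `acceptable_dims` for n x n matrices
    a_is[0] == a_is[1] && a_is[0] > 0
}
// For example, sketch_square_matrix_shape_ok(&[7, 4, 3, 3]) is true (7 x 4 channels
// of 3 x 3 points), while sketch_square_matrix_shape_ok(&[5, 3]) is false because
// a 5 x 3 frame is not square.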
+ fn acceptable_dims(_a_is: &[usize]) -> bool; + /// Project `vector` to the tangent space at `point` fn project(point: Tensor, vector: Tensor) -> Tensor; @@ -75,7 +92,7 @@ pub trait Manifold: Clone + Send + Sync { tangent: Tensor, ) -> Tensor { // Default implementation: project to tangent space at point2 - Self::project_tangent(point2, tangent.into()) + Self::project_tangent(point2, tangent) } /// Move along the manifold from `point` along the tangent vector `direction` with step size @@ -164,4 +181,8 @@ impl Manifold for Euclidean { .bool_not(); Self::is_in_manifold(point).bool_and(vector_exists) } + + fn acceptable_dims(_a_is: &[usize]) -> bool { + true + } } diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index dad0aaf..dacfd76 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -14,6 +14,27 @@ impl Manifold for Sphere { "Sphere" } + fn specific_name(s: &Shape) -> String { + let num_dims = s.num_dims(); + assert!( + num_dims > 0, + "There is at least one dimension where the manifold actually lives" + ); + let sphere_dim = *s + .dims + .last() + .expect("There is at least one dimension where the manifold actually lives"); + let (channel_dims, _) = s.dims.split_at(num_dims - 1); + if channel_dims.is_empty() { + format!("Sphere S^{} subset R^{sphere_dim}", sphere_dim - 1) + } else { + format!( + "{channel_dims:?} Channels worth of points in Sphere S^{} subset R^{sphere_dim}", + sphere_dim - 1 + ) + } + } + fn project(point: Tensor, vector: Tensor) -> Tensor { // For sphere: project vector orthogonal to point let dot_product = (point.clone() * vector.clone()).sum_dim(D - 1); @@ -54,6 +75,11 @@ impl Manifold for Sphere { let norm = point.clone().powf_scalar(2.0).sum_dim(D - 1).sqrt(); point / norm } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = *a_is.first().expect("The ambient R^n does exist"); + n > 0 + } } #[cfg(test)] @@ -63,11 +89,12 @@ mod test { use super::Sphere; use burn::{ backend::{Autodiff, NdArray}, - tensor::Tensor, + tensor::{Shape, Tensor}, }; type TestBackend = Autodiff; type TestTensor = Tensor; + type TestManyTensor = Tensor; const TOLERANCE: f32 = 1e-6; @@ -88,14 +115,29 @@ mod test { Tensor::from_floats(data, &device) } + fn create_test_matrices( + data: [[[f32; ROWS]; CHANNEL1]; CHANNEL0], + ) -> TestManyTensor { + let device = Default::default(); + Tensor::from_floats(data, &device) + } + #[test] fn test_manifold_creation() { let _manifold = >::new(); assert_eq!(>::name(), "Sphere"); - assert_eq!(>::specific_name(&burn::tensor::Shape{dims: vec![5]}), - "[] Channels worth of points in Sphere with specific n's [5]"); - assert_eq!(>::specific_name(&burn::tensor::Shape{dims: vec![10,30,5]}), - "[10, 30] Channels worth of points in Sphere with specific n's [5]"); + assert_eq!( + >::specific_name(&burn::tensor::Shape { + dims: vec![5] + }), + "Sphere S^4 subset R^5" + ); + assert_eq!( + >::specific_name(&burn::tensor::Shape { + dims: vec![10, 30, 5] + }), + "[10, 30] Channels worth of points in Sphere S^4 subset R^5" + ); } #[test] @@ -120,6 +162,47 @@ mod test { ); } + #[test] + fn test_many_projection_tangent_space() { + // Create many points on the Sphere manifold + let point_00 = [3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]; + let point_01 = [4.0 / 5.0, 0.0, 3.0 / 5.0, 0.0, 0.0, 0.0]; + let point_02 = [1.0 / 1.0, 0.0, 0.0, 0.0 / 1.0, 0.0, 0.0]; + let point_10 = [0.0 / 1.0, 0.0, 0.0, -1.0 / 1.0, 0.0, 0.0]; + let point_11 = [3.0 / 5.0, 0.0, 4.0 / 5.0, 0.0, 0.0, 0.0]; + let point_12 = [3.0 / 5.0, 0.0, 0.0, -4.0 / 5.0, 0.0, 0.0]; 
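        // Each of the six arrays above is a unit vector in R^6; stacked below they
        // form a [2, 3, 6] tensor, i.e. a 2 x 3 grid of channels with each channel
        // holding one point on the sphere S^5.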
+ let points = create_test_matrices::<6, 2, 3>([ + [point_00, point_01, point_02], + [point_10, point_11, point_12], + ]); + assert_eq!( + points.shape(), + Shape { + dims: vec![2, 3, 6] + } + ); + + // Create many direction vectors + let directions = TestManyTensor::random( + points.shape(), + burn::tensor::Distribution::Uniform(-1.0, 1.0), + &points.device(), + ); + + let projecteds = + >::project(points.clone(), directions.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (points.clone() * projecteds.clone()).sum_dim(2); + let max_entry = product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent space: absoulte value of the dot product = {}", + max_entry + ); + } + #[test] fn test_projection_preserves_tangent_vectors() { // Create a point on the Sphere manifold diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 55d4f83..1d2edc7 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -74,6 +74,12 @@ impl Manifold for SteifielsManifold { fn is_in_manifold(_point: Tensor) -> Tensor { todo!() } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = a_is[0]; + let k = a_is[1]; + n > 0 && k > 0 && k <= n + } } fn gram_schmidt(v: &Tensor) -> Tensor { From 1693a7ed8aba730ee793563d957980ece4e55077 Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 11 Dec 2025 19:55:02 -0500 Subject: [PATCH 09/13] stieffel todos --- src/manifolds/steifiel.rs | 46 +++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 1d2edc7..ef38fda 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -67,12 +67,48 @@ impl Manifold for SteifielsManifold { max_skew.lower_elem(1e-6) } - fn proj(_point: Tensor) -> Tensor { - todo!() + fn proj(mut point: Tensor) -> Tensor { + debug_assert!(point.dims().len() >= Self::RANK_PER_POINT); + if point.dims().len() > Self::RANK_PER_POINT { + // Gram_schmidt as written does so on the first two coordinates + // unlike Matrix multiplication and the rest of tensor operations + // which is assuming the last two coordinates + // and the first bunch being channels instead of vice versa + point = point.swap_dims(0, D - 2); + point = point.swap_dims(1, D - 1); + point = gram_schmidt(&point); + point = point.swap_dims(1, D - 1); + point = point.swap_dims(0, D - 2); + point + } else { + gram_schmidt(&point) + } } - fn is_in_manifold(_point: Tensor) -> Tensor { - todo!() + fn is_in_manifold(point: Tensor) -> Tensor { + let a_transpose_times_a = point.clone().transpose().matmul(point); + let all_dims = a_transpose_times_a.shape(); + debug_assert!(all_dims.num_dims() >= 2); + let shape : [usize; D] = a_transpose_times_a.shape().dims(); + debug_assert_eq!(shape[D-1], shape[D-2]); + let n = shape[D-1]; + let mut other = a_transpose_times_a.zeros_like(); + let mut ones_shape = [1usize; D]; + for i in 0..(D-2) { + ones_shape[i] = shape[i]; + } + let ones_patch = Tensor::::ones(ones_shape, &a_transpose_times_a.device()); + for diag in 0..n { + let ranges : [_;D] = std::array::from_fn(|dim| + if dim < D-2 { + 0..shape[dim] + } else { + diag..diag+1 + } + ); + other = other.slice_assign(ranges, ones_patch.clone()); + } + a_transpose_times_a.is_close(other, None, None).all_dim(D-1).all_dim(D-2) } fn acceptable_dims(a_is: &[usize]) -> bool { @@ -392,6 +428,8 @@ mod test { "Second column not normalized after retraction: 
norm = {}", norm2 ); + + assert!(SteifielsManifold::::is_in_manifold(retracted).all().into_scalar()); } #[test] From 58ce0ec045ff35b609d05c697f4c02739eac07e5 Mon Sep 17 00:00:00 2001 From: Cobord Date: Fri, 12 Dec 2025 15:00:35 -0500 Subject: [PATCH 10/13] fmt --- src/manifolds/steifiel.rs | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index ef38fda..c0ce571 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -89,26 +89,27 @@ impl Manifold for SteifielsManifold { let a_transpose_times_a = point.clone().transpose().matmul(point); let all_dims = a_transpose_times_a.shape(); debug_assert!(all_dims.num_dims() >= 2); - let shape : [usize; D] = a_transpose_times_a.shape().dims(); - debug_assert_eq!(shape[D-1], shape[D-2]); - let n = shape[D-1]; + let shape: [usize; D] = a_transpose_times_a.shape().dims(); + debug_assert_eq!(shape[D - 1], shape[D - 2]); + let n = shape[D - 1]; let mut other = a_transpose_times_a.zeros_like(); let mut ones_shape = [1usize; D]; - for i in 0..(D-2) { - ones_shape[i] = shape[i]; - } - let ones_patch = Tensor::::ones(ones_shape, &a_transpose_times_a.device()); + ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); + let ones_patch = Tensor::::ones(ones_shape, &a_transpose_times_a.device()); for diag in 0..n { - let ranges : [_;D] = std::array::from_fn(|dim| - if dim < D-2 { + let ranges: [_; D] = std::array::from_fn(|dim| { + if dim < D - 2 { 0..shape[dim] } else { - diag..diag+1 + diag..diag + 1 } - ); + }); other = other.slice_assign(ranges, ones_patch.clone()); } - a_transpose_times_a.is_close(other, None, None).all_dim(D-1).all_dim(D-2) + a_transpose_times_a + .is_close(other, None, None) + .all_dim(D - 1) + .all_dim(D - 2) } fn acceptable_dims(a_is: &[usize]) -> bool { @@ -429,7 +430,9 @@ mod test { norm2 ); - assert!(SteifielsManifold::::is_in_manifold(retracted).all().into_scalar()); + assert!(SteifielsManifold::::is_in_manifold(retracted) + .all() + .into_scalar()); } #[test] From cbded21f737d08f90dcbab9432af49cd289c5a0d Mon Sep 17 00:00:00 2001 From: Cobord Date: Mon, 15 Dec 2025 15:01:31 -0500 Subject: [PATCH 11/13] moving identity creation which is used in all matrix manifolds --- src/manifolds.rs | 5 +++ src/manifolds/matrix_groups.rs | 80 ++++++++++++++++++++++++++++++++++ src/manifolds/steifiel.rs | 20 +-------- src/manifolds/utils.rs | 28 ++++++++++++ 4 files changed, 115 insertions(+), 18 deletions(-) create mode 100644 src/manifolds/matrix_groups.rs create mode 100644 src/manifolds/utils.rs diff --git a/src/manifolds.rs b/src/manifolds.rs index 638e215..6216a1e 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -14,6 +14,11 @@ pub use steifiel::SteifielsManifold; pub mod sphere; pub use sphere::Sphere; +pub mod matrix_groups; +pub use matrix_groups::OrthogonalGroup; + +pub mod utils; + /// A Riemannian manifold defines the geometric structure for optimization. /// This is actually for a family of manifolds parameterized by some natural numbers. 
/// diff --git a/src/manifolds/matrix_groups.rs b/src/manifolds/matrix_groups.rs new file mode 100644 index 0000000..4e0dc47 --- /dev/null +++ b/src/manifolds/matrix_groups.rs @@ -0,0 +1,80 @@ +use crate::{manifolds::utils::identity_in_last_two, prelude::*}; + +#[derive(Debug, Clone, Default)] +pub struct OrthogonalGroup { + _backend: std::marker::PhantomData, +} + +impl Manifold for OrthogonalGroup { + const RANK_PER_POINT: usize = 2; + + fn new() -> Self { + OrthogonalGroup { + _backend: std::marker::PhantomData, + } + } + + fn name() -> &'static str { + if IS_SPECIAL { + "Special Orthogonal" + } else { + "Orthogonal" + } + } + + fn acceptable_dims(a_is: &[usize]) -> bool { + debug_assert!(a_is.len() >= Self::RANK_PER_POINT); + let num_dims = a_is.len(); + a_is[num_dims - 1] == a_is[num_dims - 2] + } + + fn project(_point: Tensor, _vector: Tensor) -> Tensor { + todo!() + } + + fn retract(_point: Tensor, _direction: Tensor) -> Tensor { + todo!() + } + + fn inner( + _point: Tensor, + u: Tensor, + v: Tensor, + ) -> Tensor { + // For orthogonal manifolds, we use the standard Euclidean inner product + (u * v).sum_dim(D - 1).sum_dim(D - 2) + } + + fn proj(_point: Tensor) -> Tensor { + todo!() + } + + fn is_in_manifold(point: Tensor) -> Tensor { + if Self::acceptable_shape(&point.shape()) { + return point.zeros_like().any_dim(D - 1).any_dim(D - 2); + } + let a_transpose_times_a = point.clone().transpose().matmul(point); + let all_dims = a_transpose_times_a.shape(); + debug_assert!(all_dims.num_dims() >= 2); + let other = identity_in_last_two(&a_transpose_times_a); + let in_orthogonal = a_transpose_times_a + .is_close(other, None, None) + .all_dim(D - 1) + .all_dim(D - 2); + if IS_SPECIAL { + in_orthogonal + } else { + #[allow(unused_variables)] + let has_det_one = { todo!() }; + #[allow(unreachable_code)] + in_orthogonal.bool_and(has_det_one); + } + } + + fn is_tangent_at( + _point: Tensor, + _vector: Tensor, + ) -> Tensor { + todo!() + } +} diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index c0ce571..9367525 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -1,4 +1,4 @@ -use crate::prelude::*; +use crate::{manifolds::utils::identity_in_last_two, prelude::*}; #[derive(Debug, Clone, Default)] pub struct SteifielsManifold { @@ -89,23 +89,7 @@ impl Manifold for SteifielsManifold { let a_transpose_times_a = point.clone().transpose().matmul(point); let all_dims = a_transpose_times_a.shape(); debug_assert!(all_dims.num_dims() >= 2); - let shape: [usize; D] = a_transpose_times_a.shape().dims(); - debug_assert_eq!(shape[D - 1], shape[D - 2]); - let n = shape[D - 1]; - let mut other = a_transpose_times_a.zeros_like(); - let mut ones_shape = [1usize; D]; - ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); - let ones_patch = Tensor::::ones(ones_shape, &a_transpose_times_a.device()); - for diag in 0..n { - let ranges: [_; D] = std::array::from_fn(|dim| { - if dim < D - 2 { - 0..shape[dim] - } else { - diag..diag + 1 - } - }); - other = other.slice_assign(ranges, ones_patch.clone()); - } + let other = identity_in_last_two(&a_transpose_times_a); a_transpose_times_a .is_close(other, None, None) .all_dim(D - 1) diff --git a/src/manifolds/utils.rs b/src/manifolds/utils.rs new file mode 100644 index 0000000..fe717d0 --- /dev/null +++ b/src/manifolds/utils.rs @@ -0,0 +1,28 @@ +use burn::{prelude::Backend, tensor::Tensor}; + +/// Given a tensor of shape `a_1 ... 
a_k x N x N` +/// create a tensor of the same shape +/// whose `i_1...i_k,m,n` entry is `1` if `m==n` and `0` otherwise +pub(crate) fn identity_in_last_two( + example: &Tensor, +) -> Tensor { + let shape: [usize; D] = example.shape().dims(); + debug_assert!(D >= 2); + debug_assert_eq!(shape[D - 1], shape[D - 2]); + let n = shape[D - 1]; + let mut other = example.zeros_like(); + let mut ones_shape = [1usize; D]; + ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); + let ones_patch = Tensor::::ones(ones_shape, &example.device()); + for diag in 0..n { + let ranges: [_; D] = std::array::from_fn(|dim| { + if dim < D - 2 { + 0..shape[dim] + } else { + diag..diag + 1 + } + }); + other = other.slice_assign(ranges, ones_patch.clone()); + } + other +} From 6fbe9cd8c0497dda9568ba88f6f74e6c444e7edb Mon Sep 17 00:00:00 2001 From: Cobord Date: Mon, 15 Dec 2025 17:03:57 -0500 Subject: [PATCH 12/13] move more generally useful for all matrix manifolds, ein sum --- src/lib.rs | 1 + src/lie_group.rs | 7 ++ src/manifolds/matrix_groups.rs | 8 +- src/manifolds/steifiel.rs | 120 +++--------------- src/manifolds/utils.rs | 220 +++++++++++++++++++++++++++++++++ 5 files changed, 249 insertions(+), 107 deletions(-) create mode 100644 src/lie_group.rs diff --git a/src/lib.rs b/src/lib.rs index 8dd69de..e9f4f78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod constrained_module; pub mod manifolds; +pub mod lie_group; pub mod optimizers; pub mod prelude; diff --git a/src/lie_group.rs b/src/lie_group.rs new file mode 100644 index 0000000..3cb5163 --- /dev/null +++ b/src/lie_group.rs @@ -0,0 +1,7 @@ +use burn::{prelude::Backend, tensor::Tensor}; + +use crate::prelude::Manifold; + +pub trait MonoidManifold: Clone + Send + Sync + Manifold { + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor; +} \ No newline at end of file diff --git a/src/manifolds/matrix_groups.rs b/src/manifolds/matrix_groups.rs index 4e0dc47..43ffec0 100644 --- a/src/manifolds/matrix_groups.rs +++ b/src/manifolds/matrix_groups.rs @@ -1,4 +1,4 @@ -use crate::{manifolds::utils::identity_in_last_two, prelude::*}; +use crate::{lie_group::MonoidManifold, manifolds::utils::identity_in_last_two, prelude::*}; #[derive(Debug, Clone, Default)] pub struct OrthogonalGroup { @@ -78,3 +78,9 @@ impl Manifold for OrthogonalGroup MonoidManifold for OrthogonalGroup { + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor { + points0.matmul(points1) + } +} \ No newline at end of file diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 9367525..60c7cb7 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -130,6 +130,7 @@ fn gram_schmidt(v: &Tensor) -> Tensor { #[cfg(test)] mod test { + use crate::manifolds::utils::test::{assert_matrix_close, create_test_matrix}; use crate::optimizers::LessSimpleOptimizer; use super::*; @@ -139,106 +140,9 @@ mod test { }; type TestBackend = Autodiff; - type TestTensor = Tensor; const TOLERANCE: f32 = 1e-6; - fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) { - let diff = (a.clone() - b.clone()).abs(); - let max_diff = diff.max().into_scalar(); - assert!( - max_diff < tol, - "Tensors differ by {}, tolerance: {}", - max_diff, - tol - ); - } - - fn create_test_matrix(rows: usize, cols: usize, values: Vec) -> TestTensor { - let device = Default::default(); - // Reshape the flat vector into a 2D array - let mut data = Vec::with_capacity(rows); - for chunk in values.chunks(cols) { - data.push(chunk.to_vec()); - } - - // Create tensor from nested arrays - 
match (rows, cols) { - (3, 2) => { - if data.len() >= 3 && data[0].len() >= 2 && data[1].len() >= 2 && data[2].len() >= 2 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1]], - [data[1][0], data[1][1]], - [data[2][0], data[2][1]], - ], - &device, - ) - } else { - panic!("Invalid 3x2 matrix data"); - } - } - (3, 1) => { - if data.len() >= 3 - && !data[0].is_empty() - && !data[1].is_empty() - && !data[2].is_empty() - { - Tensor::from_floats([[data[0][0]], [data[1][0]], [data[2][0]]], &device) - } else { - panic!("Invalid 3x1 matrix data"); - } - } - (3, 3) => { - if data.len() >= 3 && data[0].len() >= 3 && data[1].len() >= 3 && data[2].len() >= 3 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1], data[0][2]], - [data[1][0], data[1][1], data[1][2]], - [data[2][0], data[2][1], data[2][2]], - ], - &device, - ) - } else { - panic!("Invalid 3x3 matrix data"); - } - } - (4, 2) => { - if data.len() >= 4 - && data[0].len() >= 2 - && data[1].len() >= 2 - && data[2].len() >= 2 - && data[3].len() >= 2 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1]], - [data[1][0], data[1][1]], - [data[2][0], data[2][1]], - [data[3][0], data[3][1]], - ], - &device, - ) - } else { - panic!("Invalid 4x2 matrix data"); - } - } - (2, 2) => { - if data.len() >= 2 && data[0].len() >= 2 && data[1].len() >= 2 { - Tensor::from_floats( - [[data[0][0], data[0][1]], [data[1][0], data[1][1]]], - &device, - ) - } else { - panic!("Invalid 2x2 matrix data"); - } - } - _ => panic!("Unsupported matrix dimensions: {}x{}", rows, cols), - } - } - #[test] fn test_manifold_creation() { let _manifold = SteifielsManifold::::new(); @@ -248,7 +152,7 @@ mod test { #[test] fn test_gram_schmidt_orthogonalization() { // Test with a simple 3x2 matrix - let input = create_test_matrix(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]); + let input = create_test_matrix::(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]); let result = gram_schmidt(&input); @@ -294,7 +198,7 @@ mod test { #[test] fn test_gram_schmidt_single_column() { // Test with a single column vector - let input = create_test_matrix(3, 1, vec![3.0, 4.0, 0.0]); + let input = create_test_matrix::(3, 1, vec![3.0, 4.0, 0.0]); let result = gram_schmidt(&input); // Should be normalized to unit length @@ -311,8 +215,8 @@ mod test { ); // Should be proportional to original vector - let expected = create_test_matrix(3, 1, vec![0.6, 0.8, 0.0]); - assert_tensor_close(&result, &expected, TOLERANCE); + let expected = create_test_matrix::(3, 1, vec![0.6, 0.8, 0.0]); + assert_matrix_close(&result, &expected, TOLERANCE); } #[test] @@ -348,7 +252,7 @@ mod test { // Project the tangent vector again let projected = SteifielsManifold::::project(point.clone(), tangent.clone()); // Should be unchanged (idempotent) - assert_tensor_close(&projected, &tangent, 1e-6); + assert_matrix_close(&projected, &tangent, 1e-6); // Check the tangent space property: X^T V + V^T X = 0 let xtv = point.clone().transpose().matmul(tangent.clone()); let vtx = tangent.clone().transpose().matmul(point.clone()); @@ -422,10 +326,14 @@ mod test { #[test] fn test_gram_schmidt_identity_matrix() { // Identity matrix should remain unchanged - let identity = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); + let identity = create_test_matrix::( + 3, + 3, + vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + ); let result = gram_schmidt(&identity); - assert_tensor_close(&result, &identity, TOLERANCE); + assert_matrix_close(&result, &identity, TOLERANCE); } #[test] @@ -444,7 +352,7 @@ mod test { let 
gram_matrix = point.clone().transpose().matmul(point.clone()); let identity = create_test_matrix(2, 2, vec![1.0, 0.0, 0.0, 1.0]); - assert_tensor_close(&gram_matrix, &identity, TOLERANCE); + assert_matrix_close(&gram_matrix, &identity, TOLERANCE); // Test projection and retraction preserve this property let direction = create_test_matrix(4, 2, vec![0.1, 0.0, 0.0, 0.1, 0.2, 0.3, -0.1, 0.2]); @@ -453,7 +361,7 @@ mod test { let retracted = SteifielsManifold::::retract(point.clone(), projected * 0.1); let retracted_gram = retracted.clone().transpose().matmul(retracted.clone()); - assert_tensor_close(&retracted_gram, &identity, TOLERANCE); + assert_matrix_close(&retracted_gram, &identity, TOLERANCE); } #[test] diff --git a/src/manifolds/utils.rs b/src/manifolds/utils.rs index fe717d0..47277a5 100644 --- a/src/manifolds/utils.rs +++ b/src/manifolds/utils.rs @@ -26,3 +26,223 @@ pub(crate) fn identity_in_last_two( } other } + +/// Given a tensor of shape `a_1 ... a_k x N x N` +/// create a tensor of the same shape +/// whose `i_1...i_k,m,n` entry is `f(m)` if `m==n` and `0` otherwise +#[allow(dead_code)] +pub(crate) fn diag_i( + example: &Tensor, + diag_fun: impl Fn(usize) -> f32 +) -> Tensor { + let shape: [usize; D] = example.shape().dims(); + debug_assert!(D >= 2); + debug_assert_eq!(shape[D - 1], shape[D - 2]); + let n = shape[D - 1]; + let mut other = example.zeros_like(); + let mut ones_shape = [1usize; D]; + ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); + let ones_patch = Tensor::::ones(ones_shape, &example.device()); + for diag in 0..n { + let ranges: [_; D] = std::array::from_fn(|dim| { + if dim < D - 2 { + 0..shape[dim] + } else { + diag..diag + 1 + } + }); + other = other.slice_assign(ranges, ones_patch.clone().mul_scalar(diag_fun(diag))); + } + other +} + +/// Given a tensor of shape `l_1 ... l_k` +/// create a tensor of the same shape +/// where `l_a == l_b` do the summation +/// of `\sum_{j=1}^{l_a}` with those particular +/// entries set. +/// The rank is dropped to `D-2` +/// but because of the lack of const operations +/// we leave those dimensions there but with `1`s +/// in the shape. +pub(crate) fn ein_sum( + mut t: Tensor, + a: usize, + b: usize, +) -> Tensor { + debug_assert!(a < D); + debug_assert!(b < D); + debug_assert_ne!(a, b); + let t_shape: [usize; D] = t.shape().dims(); + debug_assert_eq!(t_shape[a], t_shape[b]); + if a != D-2 || b != D-1 { + t = t.swap_dims(a, D - 2); + t = t.swap_dims(b, D - 1); + } + let identity_last_two = identity_in_last_two(&t); + t = t.mul(identity_last_two); + t = t.sum_dim(D - 1).sum_dim(D - 2); + if a != D-2 || b != D-1 { + t = t.swap_dims(b, D - 1); + t = t.swap_dims(a, D - 2); + } + t +} + +/// Given a tensor of shape `a_1 ... a_k x N x N` +/// create a tensor of shape `a_1 ... a_k x 1 x 1` +/// whose `i_1...i_k,0,0` entry is the trace +/// of the `N x N` matrix with those previous `k` +/// fixed but the last two remaining free. 
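// [Editorial sketch, not part of the patch: the smallest concrete case of the
// doc comment above. A single 2 x 2 matrix with no channel dimensions reduces
// to a 1 x 1 tensor holding its trace, here 1 + 4 = 5. `NdArray` is the same
// CPU backend the surrounding test modules already use.]
#[cfg(test)]
#[test]
fn sketch_trace_of_a_2x2_matrix() {
    use burn::backend::NdArray;
    let m = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default());
    let tr = trace(m);
    // the matrix dimensions collapse to 1 x 1; any channel dimensions would be kept
    assert_eq!(tr.shape().dims(), [1, 1]);
    assert!((tr.into_scalar() - 5.0).abs() < 1e-6);
}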
+#[allow(dead_code)] +pub(crate) fn trace( + t: Tensor, +) -> Tensor { + ein_sum(t, D-2, D-1) +} + +#[cfg(test)] +pub(crate) mod test { + use burn::{backend::NdArray, prelude::Backend, tensor::Tensor}; + + use crate::manifolds::utils::ein_sum; + + pub(crate) fn assert_matrix_close( + a: &Tensor, + b: &Tensor, + tol: f32, + ) where + TestBackend: Backend, + ::FloatElem: PartialOrd, + { + let diff = (a.clone() - b.clone()).abs(); + let max_diff = diff.max().into_scalar(); + assert!( + max_diff < tol, + "Tensors differ by {}, tolerance: {}", + max_diff, + tol + ); + } + + pub(crate) fn create_test_matrix( + rows: usize, + cols: usize, + values: Vec, + ) -> Tensor { + debug_assert_ne!(rows, 0); + debug_assert_ne!(cols, 0); + if rows < cols { + return create_test_matrix(cols, rows, values).transpose(); + } + let device = Default::default(); + // Reshape the flat vector into a 2D array + let mut data = Vec::with_capacity(rows); + for chunk in values.chunks(cols) { + data.push(chunk.to_vec()); + } + + // Create tensor from nested arrays + match (rows, cols) { + (3, 2) => { + if data.len() >= 3 && data[0].len() >= 2 && data[1].len() >= 2 && data[2].len() >= 2 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1]], + [data[1][0], data[1][1]], + [data[2][0], data[2][1]], + ], + &device, + ) + } else { + panic!("Invalid 3x2 matrix data"); + } + } + (3, 1) => { + if data.len() >= 3 + && !data[0].is_empty() + && !data[1].is_empty() + && !data[2].is_empty() + { + Tensor::from_floats([[data[0][0]], [data[1][0]], [data[2][0]]], &device) + } else { + panic!("Invalid 3x1 matrix data"); + } + } + (3, 3) => { + if data.len() >= 3 && data[0].len() >= 3 && data[1].len() >= 3 && data[2].len() >= 3 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1], data[0][2]], + [data[1][0], data[1][1], data[1][2]], + [data[2][0], data[2][1], data[2][2]], + ], + &device, + ) + } else { + panic!("Invalid 3x3 matrix data"); + } + } + (4, 2) => { + if data.len() >= 4 + && data[0].len() >= 2 + && data[1].len() >= 2 + && data[2].len() >= 2 + && data[3].len() >= 2 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1]], + [data[1][0], data[1][1]], + [data[2][0], data[2][1]], + [data[3][0], data[3][1]], + ], + &device, + ) + } else { + panic!("Invalid 4x2 matrix data"); + } + } + (2, 2) => { + if data.len() >= 2 && data[0].len() >= 2 && data[1].len() >= 2 { + Tensor::from_floats( + [[data[0][0], data[0][1]], [data[1][0], data[1][1]]], + &device, + ) + } else { + panic!("Invalid 2x2 matrix data"); + } + } + (2, 1) => { + if data.len() >= 2 && !data[0].is_empty() && !data[1].is_empty() { + Tensor::from_floats([[data[0][0]], [data[1][0]]], &device) + } else { + panic!("Invalid 2x1 matrix data"); + } + } + (1, 1) => { + if data.len() >= 1 && !data[0].is_empty() { + Tensor::from_floats([[data[0][0]]], &device) + } else { + panic!("Invalid 1x1 matrix data"); + } + } + _ => panic!("Unsupported matrix dimensions: {}x{}", rows, cols), + } + } + + #[test] + fn small_einsum() { + let mat = create_test_matrix::(3, 3, vec![ + 3.0,4.0,5.0, + 6.0,7.0,3.0, + -10.0,-4.0,-1.0, + ]); + let ein_summed = ein_sum(mat.clone(), 0, 1); + assert_eq!(ein_summed.shape().dims(), [1,1]); + let scalar = ein_summed.into_scalar(); + assert!((scalar - 9.0).abs() <= 1e-6, "{}", scalar); + } +} From cd9df48ae38c1544ac517050cb0df108288af0a1 Mon Sep 17 00:00:00 2001 From: Cobord Date: Sun, 21 Dec 2025 17:03:36 -0500 Subject: [PATCH 13/13] util fixing and tests thereof --- src/lib.rs | 2 +- src/lie_group.rs | 4 +- src/manifolds/matrix_groups.rs | 6 +- 
src/manifolds/utils.rs | 160 +++++++++++++++++++++++++-------- 4 files changed, 131 insertions(+), 41 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e9f4f78..52d1117 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ pub mod constrained_module; -pub mod manifolds; pub mod lie_group; +pub mod manifolds; pub mod optimizers; pub mod prelude; diff --git a/src/lie_group.rs b/src/lie_group.rs index 3cb5163..71a017a 100644 --- a/src/lie_group.rs +++ b/src/lie_group.rs @@ -3,5 +3,5 @@ use burn::{prelude::Backend, tensor::Tensor}; use crate::prelude::Manifold; pub trait MonoidManifold: Clone + Send + Sync + Manifold { - fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor; -} \ No newline at end of file + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor; +} diff --git a/src/manifolds/matrix_groups.rs b/src/manifolds/matrix_groups.rs index 43ffec0..2308515 100644 --- a/src/manifolds/matrix_groups.rs +++ b/src/manifolds/matrix_groups.rs @@ -79,8 +79,8 @@ impl Manifold for OrthogonalGroup MonoidManifold for OrthogonalGroup { - fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor { +impl MonoidManifold for OrthogonalGroup { + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor { points0.matmul(points1) } -} \ No newline at end of file +} diff --git a/src/manifolds/utils.rs b/src/manifolds/utils.rs index 47277a5..2645eee 100644 --- a/src/manifolds/utils.rs +++ b/src/manifolds/utils.rs @@ -10,21 +10,18 @@ pub(crate) fn identity_in_last_two( debug_assert!(D >= 2); debug_assert_eq!(shape[D - 1], shape[D - 2]); let n = shape[D - 1]; - let mut other = example.zeros_like(); - let mut ones_shape = [1usize; D]; - ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); - let ones_patch = Tensor::::ones(ones_shape, &example.device()); - for diag in 0..n { - let ranges: [_; D] = std::array::from_fn(|dim| { - if dim < D - 2 { - 0..shape[dim] - } else { - diag..diag + 1 - } - }); - other = other.slice_assign(ranges, ones_patch.clone()); - } - other + let identity = Tensor::eye(n, &example.device()); + // Broadcasting is right aligned + // so this will act like identity in shape (1,...1,N,N) + // and then broadcast those 1's to a_1 ... a_k + // even though that is not stated in the docs of `expand` + // More honestly phrased identity of rank 2 is a function NxN -> R + // and the return value of type a1xa2...xNxN -> R is being done as + // the precomposition with the map of sets a1x...akxNxN -> 1x...1xNxN -> NxN which is the + // terminal map and identity maps in each factor and then monoidal unit laws. + // Relying on this implicit extraneous structure is the reason do not have to do this manually like in + // the below implementation of `diag_i`. + identity.expand(example.shape()) } /// Given a tensor of shape `a_1 ... a_k x N x N` @@ -33,7 +30,7 @@ pub(crate) fn identity_in_last_two( #[allow(dead_code)] pub(crate) fn diag_i( example: &Tensor, - diag_fun: impl Fn(usize) -> f32 + diag_fun: impl Fn(usize) -> f32, ) -> Tensor { let shape: [usize; D] = example.shape().dims(); debug_assert!(D >= 2); @@ -56,15 +53,12 @@ pub(crate) fn diag_i( other } -/// Given a tensor of shape `l_1 ... l_k` -/// create a tensor of the same shape -/// where `l_a == l_b` do the summation -/// of `\sum_{j=1}^{l_a}` with those particular -/// entries set. +/// Given a tensor of shape `l_1 .. l_a .. l_b .. l_D` /// The rank is dropped to `D-2` /// but because of the lack of const operations /// we leave those dimensions there but with `1`s -/// in the shape. +/// in the shape instead. +/// `l_1 .. 1 .. 1 .. 
 pub(crate) fn ein_sum(
     mut t: Tensor,
     a: usize,
@@ -75,14 +69,14 @@
     debug_assert_ne!(a, b);
     let t_shape: [usize; D] = t.shape().dims();
     debug_assert_eq!(t_shape[a], t_shape[b]);
-    if a != D-2 || b != D-1 {
+    if a != D - 2 || b != D - 1 {
         t = t.swap_dims(a, D - 2);
         t = t.swap_dims(b, D - 1);
     }
     let identity_last_two = identity_in_last_two(&t);
     t = t.mul(identity_last_two);
     t = t.sum_dim(D - 1).sum_dim(D - 2);
-    if a != D-2 || b != D-1 {
+    if a != D - 2 || b != D - 1 {
         t = t.swap_dims(b, D - 1);
         t = t.swap_dims(a, D - 2);
     }
@@ -95,17 +89,23 @@
 /// of the `N x N` matrix with those previous `k`
 /// fixed but the last two remaining free.
 #[allow(dead_code)]
-pub(crate) fn trace(
-    t: Tensor,
-) -> Tensor {
-    ein_sum(t, D-2, D-1)
+pub(crate) fn trace(t: Tensor) -> Tensor {
+    // According to the docs, burn::tensor::linalg looks like a module, but
+    // calling `burn::tensor::linalg::trace` does not work here.
+    // That trace also reduces the rank by 1, and we do not have the const
+    // generics to express {D-1}; its docs are likewise inconsistent about
+    // D0 versus D-1 for the same reason, so it lets code compile that
+    // really should not, without making that visible.
+    // It is easier to keep the rank and use `1`s in the shape, as here.
+    ein_sum(t, D - 2, D - 1)
 }
 
 #[cfg(test)]
 pub(crate) mod test {
+
     use burn::{backend::NdArray, prelude::Backend, tensor::Tensor};
 
-    use crate::manifolds::utils::ein_sum;
+    use crate::manifolds::utils::{diag_i, ein_sum, identity_in_last_two};
 
     pub(crate) fn assert_matrix_close(
@@ -235,14 +235,104 @@
     #[test]
     fn small_einsum() {
-        let mat = create_test_matrix::(3, 3, vec![
-            3.0,4.0,5.0,
-            6.0,7.0,3.0,
-            -10.0,-4.0,-1.0,
-        ]);
+        let mat = create_test_matrix::(
+            3,
+            3,
+            vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0],
+        );
         let ein_summed = ein_sum(mat.clone(), 0, 1);
-        assert_eq!(ein_summed.shape().dims(), [1,1]);
+        assert_eq!(ein_summed.shape().dims(), [1, 1]);
         let scalar = ein_summed.into_scalar();
         assert!((scalar - 9.0).abs() <= 1e-6, "{}", scalar);
     }
+
+    #[test]
+    fn identity_test() {
+        {
+            let mat = create_test_matrix::(
+                3,
+                3,
+                vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0],
+            );
+            let mat = mat.expand([3, 3]);
+            let identity_mat = identity_in_last_two(&mat);
+            let expected = Tensor::eye(3, &identity_mat.device());
+            assert_matrix_close(&identity_mat, &expected, 1e-6);
+        }
+        let mat = create_test_matrix::(
+            3,
+            3,
+            vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0],
+        );
+        let expanded_shape = [3, 3, 3, 3, 3];
+        let mat = mat.expand(expanded_shape);
+        let identity_mat = identity_in_last_two(&mat);
+        for idx in 0..expanded_shape[0] {
+            for jdx in 0..expanded_shape[1] {
+                for kdx in 0..expanded_shape[2] {
+                    let slice = identity_mat
+                        .clone()
+                        .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3])
+                        .reshape([3, 3]);
+                    let expected = Tensor::eye(3, &slice.device());
+                    assert_matrix_close(&slice, &expected, 1e-6);
+                }
+            }
+        }
+        let mat = create_test_matrix::(
+            3,
+            3,
+            vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0],
+        );
+        let expanded_shape = [29, 483, 2, 3, 3];
+        let mat = mat.expand(expanded_shape);
+        let identity_mat = identity_in_last_two(&mat);
+        for idx in 0..expanded_shape[0] {
+            for jdx in 0..expanded_shape[1] {
+                for kdx in 0..expanded_shape[2] {
+                    let slice = identity_mat
+                        .clone()
+                        .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3])
+                        .reshape([3, 3]);
+                    let expected = Tensor::eye(3, &slice.device());
+                    assert_matrix_close(&slice, &expected, 
1e-6); + } + } + } + } + + #[test] + fn diag_test() { + let diag_entries = [2.0, 7.0, 9.0]; + let expected = + create_test_matrix::(3, 3, vec![2.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 9.0]); + let expanded_shape = [3, 3, 3, 3, 3]; + let mat = expected.clone().expand(expanded_shape); + let identity_mat = diag_i(&mat, |i| diag_entries[i]); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + let expanded_shape = [10, 9, 20, 3, 3]; + let mat = expected.clone().expand(expanded_shape); + let identity_mat = diag_i(&mat, |i| diag_entries[i]); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + } }
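
As a usage note on the utils.rs hunks above: because `trace` is implemented as `ein_sum(t, D - 2, D - 1)`, it keeps the rank and leaves `1`s in the last two dimensions, so a stack of two `N x N` matrices comes back with shape `[2, 1, 1]` holding the per-matrix traces. A minimal sketch of a further test in the same spirit, not part of the patch; it assumes `trace` is also brought into the test module's imports, the test name is made up here, and the generic parameters are written out explicitly:

    #[test]
    fn batched_trace_sketch() {
        // Two stacked 2x2 matrices whose traces are 1 + 4 = 5 and 5 + 7 = 12.
        let device = Default::default();
        let mats = Tensor::<NdArray, 3>::from_floats(
            [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 0.0], [0.0, 7.0]]],
            &device,
        );
        let traced = trace(mats);
        // The rank is preserved: shape [2, 1, 1] rather than [2].
        assert_eq!(traced.shape().dims(), [2, 1, 1]);
        let expected = Tensor::<NdArray, 3>::from_floats([[[5.0]], [[12.0]]], &device);
        let max_diff = (traced - expected).abs().max().into_scalar();
        assert!(max_diff <= 1e-6, "{}", max_diff);
    }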