diff --git a/examples/multi_constraints.rs b/examples/multi_constraints.rs index bad4673..8c3dab2 100644 --- a/examples/multi_constraints.rs +++ b/examples/multi_constraints.rs @@ -1,10 +1,12 @@ +use burn::module::Module; use burn::module::ModuleVisitor; -use burn::nn::LinearConfig; use burn::nn::Linear; -use burn::module::Module; +use burn::nn::LinearConfig; -use manopt_rs::manifolds::Constrained; -use manopt_rs::optimizers::multiple::{MultiManifoldOptimizer, MultiManifoldOptimizerConfig, ManifoldOptimizable}; +use manopt_rs::constrained_module::Constrained; +use manopt_rs::optimizers::multiple::{ + ManifoldOptimizable, MultiManifoldOptimizer, MultiManifoldOptimizerConfig, +}; use manopt_rs::prelude::*; // Example: User-defined custom manifold @@ -12,6 +14,8 @@ use manopt_rs::prelude::*; pub struct CustomSphereManifold; impl Manifold for CustomSphereManifold { + const RANK_PER_POINT: usize = 1; + fn new() -> Self { Self } @@ -22,30 +26,62 @@ impl Manifold for CustomSphereManifold { fn project(point: Tensor, vector: Tensor) -> Tensor { // For sphere: project vector orthogonal to point - let dot_product = (point.clone() * vector.clone()).sum(); - vector - point * dot_product.unsqueeze() + debug_assert!(point.shape() == vector.shape()); + let dot_product = + (point.clone() * vector.clone()).sum_dim(D - >::RANK_PER_POINT); + vector - point * dot_product } fn retract(point: Tensor, direction: Tensor) -> Tensor { // For sphere: normalize the result + debug_assert!(point.shape() == direction.shape()); let new_point = point + direction; - let norm = new_point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + let norm = new_point + .clone() + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT) + .sqrt(); new_point / norm } - fn inner(_point: Tensor, u: Tensor, v: Tensor) -> Tensor { - u * v + fn inner( + _point: Tensor, + u: Tensor, + v: Tensor, + ) -> Tensor { + (u * v).sum_dim(D - >::RANK_PER_POINT) } fn proj(point: Tensor) -> Tensor { // Project point onto unit sphere - let norm = point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze(); + let norm = point + .clone() + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT) + .sqrt(); point / norm } - fn is_in_manifold(_point: Tensor) -> bool { - // For now, just return true - in a real implementation you'd check the constraint - true + fn is_in_manifold(point: Tensor) -> Tensor { + let r_squared = point + .powf_scalar(2.0) + .sum_dim(D - >::RANK_PER_POINT); + let one = r_squared.ones_like(); + r_squared.is_close(one, None, None) + } + + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let dot_product = (point * vector).sum_dim(D - >::RANK_PER_POINT); + let zeros = dot_product.zeros_like(); + dot_product.is_close(zeros, None, Some(1e-6)) + } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = *a_is.first().expect("The ambient R^n does exist"); + n > 0 } } @@ -53,7 +89,7 @@ impl Manifold for CustomSphereManifold { pub struct TestModel { // Euclidean constrained linear layer linear_euclidean: Constrained, Euclidean>, - // Custom sphere constrained linear layer + // Custom sphere constrained linear layer linear_sphere: Constrained, CustomSphereManifold>, // Regular unconstrained linear layer linear_regular: Linear, @@ -125,7 +161,7 @@ impl TestModel { let linear_euclidean = LinearConfig::new(10, 5).init(device); let linear_sphere = LinearConfig::new(5, 3).init(device); let linear_regular = LinearConfig::new(3, 1).init(device); - + Self { linear_euclidean: Constrained::new(linear_euclidean), linear_sphere: 
Constrained::new(linear_sphere), @@ -138,12 +174,26 @@ struct ManifoldAwareVisitor; impl ModuleVisitor for ManifoldAwareVisitor { fn visit_float(&mut self, id: burn::module::ParamId, tensor: &Tensor) { - println!("Visiting parameter: {:?} with shape: {:?}", id, tensor.dims()); + println!( + "Visiting parameter: {:?} with shape: {:?}", + id, + tensor.dims() + ); } - fn visit_int(&mut self, _id: burn::module::ParamId, _tensor: &Tensor) {} + fn visit_int( + &mut self, + _id: burn::module::ParamId, + _tensor: &Tensor, + ) { + } - fn visit_bool(&mut self, _id: burn::module::ParamId, _tensor: &Tensor) {} + fn visit_bool( + &mut self, + _id: burn::module::ParamId, + _tensor: &Tensor, + ) { + } } fn main() { @@ -151,53 +201,74 @@ fn main() { type AutoDiffBackend = burn::backend::Autodiff; let device = Default::default(); - + // Create a model with mixed manifold constraints let model = TestModel::::new(&device); - + println!("=== Model Structure ==="); - println!("Euclidean layer manifold: {}", model.linear_euclidean.manifold_name::()); - println!("Sphere layer manifold: {}", model.linear_sphere.manifold_name::()); - + println!( + "Euclidean layer manifold: {}", + model.linear_euclidean.manifold_name::() + ); + println!( + "Sphere layer manifold: {}", + model.linear_sphere.manifold_name::() + ); + // Create multi-manifold optimizer let config = MultiManifoldOptimizerConfig::default(); let mut optimizer = MultiManifoldOptimizer::new(config); - + // Collect manifold information from the model optimizer.collect_manifolds(&model); - + // Register custom manifold for specific parameters (if needed) optimizer.register_manifold::("linear_sphere.weight".to_string()); - + println!("\n=== Manifold Information ==="); - println!("Euclidean info: {:?}", model.linear_euclidean.get_manifold_info()); + println!( + "Euclidean info: {:?}", + model.linear_euclidean.get_manifold_info() + ); println!("Sphere info: {:?}", model.linear_sphere.get_manifold_info()); - + // Example of applying constraints let constrained_model = optimizer.apply_constraints(model); - + // Visit the model to see parameter structure println!("\n=== Parameter Structure ==="); let mut visitor = ManifoldAwareVisitor; constrained_model.visit(&mut visitor); - + println!("\n=== Demonstrating Custom Manifold Operations ==="); - + // Show how the custom sphere manifold works let point = Tensor::::from_floats([3.0, 4.0, 0.0], &device); let vector = Tensor::::from_floats([1.0, 1.0, 1.0], &device); - + println!("Original point: {:?}", point.to_data()); println!("Original vector: {:?}", vector.to_data()); - + // Project point to sphere let projected_point = CustomSphereManifold::proj(point.clone()); println!("Point projected to sphere: {:?}", projected_point.to_data()); - + // Project vector to tangent space let projected_vector = CustomSphereManifold::project(projected_point.clone(), vector); - println!("Vector projected to tangent space: {:?}", projected_vector.to_data()); - + println!( + "Vector projected to tangent space: {:?}", + projected_vector.to_data() + ); + // Check if point is on manifold - println!("Is projected point on sphere? {}", CustomSphereManifold::is_in_manifold(projected_point)); + println!( + "Is projected point on sphere? {}", + CustomSphereManifold::is_in_manifold(projected_point.clone()) + ); + + // Check if vector is tangent at point on manifold + println!( + "Is projected vector tangent to point on sphere? 
{}", + CustomSphereManifold::is_tangent_at(projected_point, projected_vector) + ); } diff --git a/examples/optimization_demo.rs b/examples/optimization_demo.rs index 9d8c36c..8bc7c27 100644 --- a/examples/optimization_demo.rs +++ b/examples/optimization_demo.rs @@ -1,5 +1,7 @@ +use std::collections::HashMap; + use burn::optim::SimpleOptimizer; -use manopt_rs::prelude::*; +use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*}; fn main() { // Configure the optimizer @@ -18,6 +20,8 @@ fn main() { Tensor::::from_floats([0.0, 0.0, 0.0], &Default::default()); let mut state = None; + let mut loss_decay: HashMap = HashMap::new(); + println!("Target: {}", target); println!("Initial x: {}", x); println!("\nOptimization steps:"); @@ -35,13 +39,36 @@ fn main() { // Print progress every 10 steps if step % 10 == 0 { let loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); - println!("Step {}: x = {}, loss = {}", step, x, loss); + let loss_scalar = loss.into_scalar(); + println!("Step {}: x = {}, loss = {:.5}", step, x, loss_scalar); + loss_decay.insert(step, loss_scalar); } } + println!("\nResult after 100:"); + println!("x = {}", x); + println!("Target = {}", target); + let final_loss = (x.clone() - target.clone()) + .powf_scalar(2.0) + .sum() + .into_scalar(); + println!("Loss after 100 = {:.5}", final_loss); + loss_decay.insert(100, final_loss); + + // Perform optimization steps + (x, state) = optimizer.many_steps(|_| 1.0, 400, |x| (x - target.clone()) * 2.0, x, state); + println!("\nFinal result:"); println!("x = {}", x); println!("Target = {}", target); - let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum(); - println!("Final loss = {}", final_loss); + let final_loss = (x.clone() - target.clone()) + .powf_scalar(2.0) + .sum() + .into_scalar(); + println!("Final loss = {:.5}", final_loss); + println!("State is set {}", state.is_some()); + loss_decay.insert(500, final_loss); + let mut sorted_losses: Vec<(usize, f32)> = loss_decay.into_iter().collect(); + sorted_losses.sort_by_key(|z| z.0); + println!("The loss decayed as follows: {:?}", sorted_losses); } diff --git a/examples/riemannian_adam_demo.rs b/examples/riemannian_adam_demo.rs index 6609ed7..9238f6d 100644 --- a/examples/riemannian_adam_demo.rs +++ b/examples/riemannian_adam_demo.rs @@ -1,5 +1,5 @@ use burn::optim::SimpleOptimizer; -use manopt_rs::prelude::*; +use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*}; fn main() { println!("Testing Riemannian Adam optimizer..."); @@ -20,8 +20,20 @@ fn main() { let (new_tensor, state) = optimizer.step(1.0, tensor.clone(), grad, None); println!("Original tensor: {}", tensor); - println!("New tensor: {}", new_tensor); + println!("New tensor: {:.4}", new_tensor); println!("State initialized: {}", state.is_some()); + fn grad_function(_x: Tensor) -> Tensor { + Tensor::::ones([2, 2], &Default::default()) + } + // Perform more optimization steps + let (new_tensor, state) = + optimizer.many_steps(|_| 1.0, 9, grad_function, tensor.clone(), state); + println!("After even more steps Tensor: {:.4}", new_tensor); + println!( + "After even more steps State initialized: {}", + state.is_some() + ); + println!("Riemannian Adam test completed successfully!"); } diff --git a/src/constrained_module.rs b/src/constrained_module.rs new file mode 100644 index 0000000..01ec1f0 --- /dev/null +++ b/src/constrained_module.rs @@ -0,0 +1,217 @@ +//! Riemannian manifolds for constrained optimization. +//! +//! This module defines manifolds and their operations for Riemannian optimization. 
+ +use std::{fmt::Debug, marker::PhantomData}; + +use crate::prelude::*; + +use burn::{ + module::{AutodiffModule, ModuleDisplay}, + tensor::backend::AutodiffBackend, +}; + +#[derive(Clone, Debug)] +pub struct Constrained { + module: M, + _manifold: PhantomData, +} + +impl Module for Constrained +where + M: Module, + B: Backend, + Man: Clone + Debug + Send, +{ + type Record = M::Record; + + fn collect_devices(&self, devices: burn::module::Devices) -> burn::module::Devices { + self.module.collect_devices(devices) + } + + fn fork(self, device: &B::Device) -> Self { + let module = self.module.fork(device); + Self { + module, + _manifold: PhantomData, + } + } + + fn to_device(self, device: &B::Device) -> Self { + let module = self.module.to_device(device); + Self { + module, + _manifold: PhantomData, + } + } + + fn visit>(&self, visitor: &mut Visitor) { + self.module.visit(visitor); + } + + fn map>(self, mapper: &mut Mapper) -> Self { + let module = self.module.map(mapper); + Self { + module, + _manifold: PhantomData, + } + } + + fn load_record(self, record: Self::Record) -> Self { + let module = self.module.load_record(record); + Self { + module, + _manifold: PhantomData, + } + } + + fn into_record(self) -> Self::Record { + self.module.into_record() + } +} + +impl AutodiffModule for Constrained +where + M: AutodiffModule, + B: AutodiffBackend, + Man: Clone + Debug + Send, +{ + type InnerModule = M::InnerModule; + + fn valid(&self) -> Self::InnerModule { + self.module.valid() + } +} + +impl burn::module::ModuleDisplayDefault for Constrained +where + M: burn::module::ModuleDisplayDefault, + Man: Clone + Debug + Send, +{ + fn content(&self, content: burn::module::Content) -> Option { + self.module.content(content) + } +} + +impl ModuleDisplay for Constrained +where + M: ModuleDisplay, + Man: Clone + Debug + Send, +{ + fn format(&self, passed_settings: burn::module::DisplaySettings) -> String { + format!("Constrained<{}>", self.module.format(passed_settings)) + } +} + +impl Constrained { + pub fn new(module: M) -> Self { + Self { + module, + _manifold: PhantomData, + } + } + + /// Get a reference to the inner module + pub fn inner(&self) -> &M { + &self.module + } + + /// Get a mutable reference to the inner module + pub fn inner_mut(&mut self) -> &mut M { + &mut self.module + } + + /// Apply manifold projection to a tensor - requires explicit Backend type + pub fn project_tensor( + &self, + point: Tensor, + vector: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::project(point, vector) + } + + /// Apply manifold retraction to a tensor - requires explicit Backend type + pub fn retract_tensor( + &self, + point: Tensor, + direction: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::retract(point, direction) + } + + /// Convert Euclidean gradient to Riemannian gradient - requires explicit Backend type + pub fn euclidean_to_riemannian( + &self, + point: Tensor, + grad: Tensor, + ) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::egrad2rgrad(point, grad) + } + + /// Project point onto manifold - requires explicit Backend type + pub fn project_to_manifold(&self, point: Tensor) -> Tensor + where + B: Backend, + M: Module, + Man: Manifold + Clone + Debug + Send, + { + Man::proj(point) + } + + /// Get the manifold name + pub fn manifold_name(&self) -> &'static str + where + B: Backend, + Man: Manifold, + { + Man::name() + } +} + +/// Trait for 
modules that have manifold constraints +pub trait ConstrainedModule { + /// Apply manifold constraints to all parameters in the module + #[must_use] + fn apply_manifold_constraints(self) -> Self; + + /// Get information about the manifold constraints + fn get_manifold_info(&self) -> std::collections::HashMap; + + /// Check if this module has manifold constraints + fn has_manifold_constraints(&self) -> bool { + true + } +} + +/// Blanket implementation for Constrained wrapper +impl ConstrainedModule for Constrained +where + M: Module, + B: Backend, + Man: Manifold + Clone + Debug + Send, +{ + fn apply_manifold_constraints(self) -> Self { + self + } + + fn get_manifold_info(&self) -> std::collections::HashMap { + let mut info = std::collections::HashMap::new(); + info.insert("manifold_type".to_string(), Man::name().to_string()); + info + } +} diff --git a/src/lib.rs b/src/lib.rs index 7bdbd5c..52d1117 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +pub mod constrained_module; +pub mod lie_group; pub mod manifolds; pub mod optimizers; pub mod prelude; diff --git a/src/lie_group.rs b/src/lie_group.rs new file mode 100644 index 0000000..71a017a --- /dev/null +++ b/src/lie_group.rs @@ -0,0 +1,7 @@ +use burn::{prelude::Backend, tensor::Tensor}; + +use crate::prelude::Manifold; + +pub trait MonoidManifold: Clone + Send + Sync + Manifold { + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor; +} diff --git a/src/manifolds.rs b/src/manifolds.rs index 1cc2802..6216a1e 100644 --- a/src/manifolds.rs +++ b/src/manifolds.rs @@ -1,20 +1,27 @@ //! Riemannian manifolds for constrained optimization. //! -//! This module defines manifolds and their operations for Riemannian optimization. +//! This module defines manifolds and their operations. //! Each manifold implements geometric operations like projection, retraction, //! exponential maps, and parallel transport. -use std::{fmt::Debug, marker::PhantomData}; +use std::fmt::Debug; use crate::prelude::*; pub mod steifiel; -use burn::{module::{AutodiffModule, ModuleDisplay}, tensor::backend::AutodiffBackend}; pub use steifiel::SteifielsManifold; pub mod sphere; +pub use sphere::Sphere; + +pub mod matrix_groups; +pub use matrix_groups::OrthogonalGroup; + +pub mod utils; /// A Riemannian manifold defines the geometric structure for optimization. +/// This is actually for a family of manifolds parameterized by some natural numbers. 
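+/// (For example, `Sphere` has `RANK_PER_POINT = 1`, so a tensor of shape `[10, 30, 5]`
+/// is read as a `10 x 30` block of channels, each holding one point of `S^4` inside `R^5`;
+/// `SteifielsManifold` and `OrthogonalGroup` use `RANK_PER_POINT = 2`, so the last two
+/// dimensions of the shape form each matrix-valued point.)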
+/// /// /// This trait provides all the necessary operations for Riemannian optimization: /// - Tangent space projections @@ -23,62 +30,67 @@ pub mod sphere; /// - Parallel transport /// - Riemannian inner products /// -/// # Example Implementation -/// -/// ```rust -/// use manopt_rs::prelude::*; -/// -/// #[derive(Clone)] -/// struct MyManifold; -/// -/// impl Manifold for MyManifold { -/// fn new() -> Self { MyManifold } -/// fn name() -> &'static str { "MyManifold" } -/// -/// fn project(point: Tensor, vector: Tensor) -> Tensor { -/// // Project vector to tangent space at point -/// vector -/// } -/// -/// fn retract(point: Tensor, direction: Tensor) -> Tensor { -/// // Move along manifold from point in direction with step size -/// point + direction -/// } -/// -/// fn inner(_point: Tensor, u: Tensor, v: Tensor) -> Tensor { -/// // Riemannian inner product at point -/// u * v -/// } -/// } -/// ``` pub trait Manifold: Clone + Send + Sync { + const RANK_PER_POINT: usize; + fn new() -> Self; fn name() -> &'static str; - + #[must_use] + fn specific_name(s: &Shape) -> String { + let dims = &s.dims; + let num_dims = dims.len(); + let (channel_dims, manifold_dims) = dims.split_at(num_dims - Self::RANK_PER_POINT); + format!( + "{channel_dims:?} Channels worth of points in {} with specific n's {manifold_dims:?}", + Self::name() + ) + } + + /// The manifold lives in `R^a_1 \times R^{a_{RANK_PER_POINT}}` + /// so if we have a Tensor of shape `s` + /// then it's last `RANK_PER_POINT` dimensions will be those a's + /// with the previous dimensions being used as channels. + /// Those a's then must be allowed + /// For example in a Matrix Lie group they will be + /// `R^n \times R^n`, giving a constraint that those last + /// two dimensions in the shape be equal to each other. + #[must_use] + fn acceptable_shape(s: &Shape) -> bool { + let enough_points = s.num_dims() >= Self::RANK_PER_POINT; + if !enough_points { + return false; + } + let (_, a_i) = s.dims.split_at(s.num_dims() - Self::RANK_PER_POINT); + Self::acceptable_dims(a_i) + } + + /// The manifold lives in `R^a_1 \times R^{a_{RANK_PER_POINT}}` + /// Those a's must be allowed . + /// For example in a Matrix Lie group they will be + /// `R^n \times R^n`, giving a constraint that those last + /// two dimensions in the shape be equal to each other. + /// For the purposes of this, we are allowed to assume the slice + /// is of length `Self::RANK_PER_POINT` + /// because it should only be called through `Self::acceptable_shape`. + /// Putting this in the type would not be allowed without unstable features. 
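+/// For instance, `OrthogonalGroup` below only needs the two trailing dimensions to
+/// agree (a sketch restating its implementation):
+///
+/// ```rust,ignore
+/// fn acceptable_dims(a_is: &[usize]) -> bool {
+///     let num_dims = a_is.len();
+///     a_is[num_dims - 1] == a_is[num_dims - 2]
+/// }
+/// ```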
+ fn acceptable_dims(_a_is: &[usize]) -> bool; + + /// Project `vector` to the tangent space at `point` fn project(point: Tensor, vector: Tensor) -> Tensor; - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor; - /// Convert Euclidean gradient to Riemannian gradient + /// Convert Euclidean gradient `grad` to Riemannian gradient at `point` fn egrad2rgrad(point: Tensor, grad: Tensor) -> Tensor { Self::project(point, grad) } - /// Riemannian inner product at a given point - fn inner(point: Tensor, u: Tensor, v: Tensor) - -> Tensor; - - /// Exponential map: move from point along tangent vector u with step size - fn expmap( - point: Tensor, - direction: Tensor, - ) -> Tensor { - Self::retract(point, direction) + /// Project `vector` to the tangent space at `point` + fn project_tangent(point: Tensor, vector: Tensor) -> Tensor { + Self::project(point, vector) } - /// Parallel transport of tangent vector from point1 to point2 + /// Parallel transport of a tangent vector `tangent` from `point1` to `point2` + /// By default, this is not accurately implemented and ignores the metric/connection + /// just projecting to the tangent space. fn parallel_transport( _point1: Tensor, point2: Tensor, @@ -88,21 +100,32 @@ pub trait Manifold: Clone + Send + Sync { Self::project_tangent(point2, tangent) } - /// Project vector to tangent space at point - fn project_tangent(point: Tensor, vector: Tensor) -> Tensor { - Self::project(point, vector) - } + /// Move along the manifold from `point` along the tangent vector `direction` with step size + fn retract(point: Tensor, direction: Tensor) -> Tensor; - /// Project point onto manifold - fn proj(point: Tensor) -> Tensor { - point + /// Exponential map: move from `point` along tangent vector `direction` with step size + fn expmap(point: Tensor, direction: Tensor) -> Tensor { + Self::retract(point, direction) } - /// Check if a point is in the manifold. - /// By default, this is not implemented and returns `false`. - fn is_in_manifold(_point: Tensor) -> bool { - false - } + /// Riemannian inner product at a given `point` + /// `u` and `v` are in the tangent space at `point` + fn inner(point: Tensor, u: Tensor, v: Tensor) + -> Tensor; + + /// Project `point` onto manifold + fn proj(point: Tensor) -> Tensor; + + /// Check if a `point` is in the manifold. + fn is_in_manifold(point: Tensor) -> Tensor; + + /// Check if a `vector` is in the tangent space at `point` + /// given that `point` is in the manifold. + /// By default, this is not accurately implemented and returns `false`. 
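+/// On the unit sphere, for example, tangency at `point` amounts to the dot product over
+/// the last `RANK_PER_POINT` dimensions vanishing up to tolerance, which is exactly what
+/// `Sphere::is_tangent_at` below does:
+///
+/// ```rust,ignore
+/// let dot_product = (point * vector).sum_dim(D - 1);
+/// let zero = dot_product.zeros_like();
+/// dot_product.is_close(zero, None, Some(1e-6))
+/// ```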
+ fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor; } /// Euclidean manifold - the simplest case where no projection is needed @@ -110,6 +133,8 @@ pub trait Manifold: Clone + Send + Sync { pub struct Euclidean; impl Manifold for Euclidean { + const RANK_PER_POINT: usize = 1; + fn new() -> Self { Self } @@ -122,10 +147,7 @@ impl Manifold for Euclidean { vector } - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor { + fn retract(point: Tensor, direction: Tensor) -> Tensor { point + direction } @@ -134,203 +156,38 @@ impl Manifold for Euclidean { u: Tensor, v: Tensor, ) -> Tensor { - u * v - } - - fn is_in_manifold(_point: Tensor) -> bool { - true - } -} - -#[derive(Clone, Debug)] -pub struct Constrained { - module: M, - _manifold: PhantomData, -} - -impl Module for Constrained -where - M: Module, - B: Backend, - Man: Clone + Debug + Send, -{ - type Record = M::Record; - - fn collect_devices(&self, devices: burn::module::Devices) -> burn::module::Devices { - self.module.collect_devices(devices) + (u * v).sum_dim(D - 1) } - fn fork(self, device: &B::Device) -> Self { - let module = self.module.fork(device); - Self { - module, - _manifold: PhantomData, - } - } - - fn to_device(self, device: &B::Device) -> Self { - let module = self.module.to_device(device); - Self { - module, - _manifold: PhantomData, - } - } - - fn visit>(&self, visitor: &mut Visitor) { - self.module.visit(visitor); - } - - fn map>(self, mapper: &mut Mapper) -> Self { - let module = self.module.map(mapper); - Self { - module, - _manifold: PhantomData, - } - } - - fn load_record(self, record: Self::Record) -> Self { - let module = self.module.load_record(record); - Self { - module, - _manifold: PhantomData, - } - } - - fn into_record(self) -> Self::Record { - self.module.into_record() - } -} - - -impl AutodiffModule for Constrained -where - M: AutodiffModule, - B: AutodiffBackend, - Man: Clone + Debug + Send, -{ - type InnerModule = M::InnerModule; - - fn valid(&self) -> Self::InnerModule { - self.module.valid() - } -} - -impl burn::module::ModuleDisplayDefault for Constrained -where - M: burn::module::ModuleDisplayDefault, - Man: Clone + Debug + Send, -{ - fn content(&self, content: burn::module::Content) -> Option { - self.module.content(content) - } -} - -impl ModuleDisplay for Constrained -where - M: ModuleDisplay, - Man: Clone + Debug + Send, -{ - fn format(&self, passed_settings: burn::module::DisplaySettings) -> String { - format!("Constrained<{}>", self.module.format(passed_settings)) + fn is_in_manifold( + point: Tensor, + ) -> burn::tensor::Tensor { + point + .clone() + .detach() + .is_nan() + .any_dim(>::RANK_PER_POINT) + .bool_not() } -} -impl Constrained { - pub fn new(module: M) -> Self { - Self { - module, - _manifold: PhantomData, - } - } - - /// Get a reference to the inner module - pub fn inner(&self) -> &M { - &self.module - } - - /// Get a mutable reference to the inner module - pub fn inner_mut(&mut self) -> &mut M { - &mut self.module - } - - /// Apply manifold projection to a tensor - requires explicit Backend type - pub fn project_tensor(&self, point: Tensor, vector: Tensor) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::project(point, vector) - } - - /// Apply manifold retraction to a tensor - requires explicit Backend type - pub fn retract_tensor(&self, point: Tensor, direction: Tensor) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::retract(point, direction) - } - - /// 
Convert Euclidean gradient to Riemannian gradient - requires explicit Backend type - pub fn euclidean_to_riemannian(&self, point: Tensor, grad: Tensor) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::egrad2rgrad(point, grad) - } - - /// Project point onto manifold - requires explicit Backend type - pub fn project_to_manifold(&self, point: Tensor) -> Tensor - where - B: Backend, - M: Module, - Man: Manifold + Clone + Debug + Send, - { - Man::proj(point) - } - - /// Get the manifold name - pub fn manifold_name(&self) -> &'static str - where - B: Backend, - Man: Manifold, - { - Man::name() + fn proj(point: Tensor) -> Tensor { + point } -} -/// Trait for modules that have manifold constraints -pub trait ConstrainedModule { - /// Apply manifold constraints to all parameters in the module - fn apply_manifold_constraints(self) -> Self; - - /// Get information about the manifold constraints - fn get_manifold_info(&self) -> std::collections::HashMap; - - /// Check if this module has manifold constraints - fn has_manifold_constraints(&self) -> bool { + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let vector_exists = vector + .clone() + .detach() + .is_nan() + .any_dim(>::RANK_PER_POINT) + .bool_not(); + Self::is_in_manifold(point).bool_and(vector_exists) + } + + fn acceptable_dims(_a_is: &[usize]) -> bool { true } } - -/// Blanket implementation for Constrained wrapper -impl ConstrainedModule for Constrained -where - M: Module, - B: Backend, - Man: Manifold + Clone + Debug + Send, -{ - fn apply_manifold_constraints(self) -> Self { - self - } - - fn get_manifold_info(&self) -> std::collections::HashMap { - let mut info = std::collections::HashMap::new(); - info.insert("manifold_type".to_string(), Man::name().to_string()); - info - } -} \ No newline at end of file diff --git a/src/manifolds/matrix_groups.rs b/src/manifolds/matrix_groups.rs new file mode 100644 index 0000000..2308515 --- /dev/null +++ b/src/manifolds/matrix_groups.rs @@ -0,0 +1,86 @@ +use crate::{lie_group::MonoidManifold, manifolds::utils::identity_in_last_two, prelude::*}; + +#[derive(Debug, Clone, Default)] +pub struct OrthogonalGroup { + _backend: std::marker::PhantomData, +} + +impl Manifold for OrthogonalGroup { + const RANK_PER_POINT: usize = 2; + + fn new() -> Self { + OrthogonalGroup { + _backend: std::marker::PhantomData, + } + } + + fn name() -> &'static str { + if IS_SPECIAL { + "Special Orthogonal" + } else { + "Orthogonal" + } + } + + fn acceptable_dims(a_is: &[usize]) -> bool { + debug_assert!(a_is.len() >= Self::RANK_PER_POINT); + let num_dims = a_is.len(); + a_is[num_dims - 1] == a_is[num_dims - 2] + } + + fn project(_point: Tensor, _vector: Tensor) -> Tensor { + todo!() + } + + fn retract(_point: Tensor, _direction: Tensor) -> Tensor { + todo!() + } + + fn inner( + _point: Tensor, + u: Tensor, + v: Tensor, + ) -> Tensor { + // For orthogonal manifolds, we use the standard Euclidean inner product + (u * v).sum_dim(D - 1).sum_dim(D - 2) + } + + fn proj(_point: Tensor) -> Tensor { + todo!() + } + + fn is_in_manifold(point: Tensor) -> Tensor { + if Self::acceptable_shape(&point.shape()) { + return point.zeros_like().any_dim(D - 1).any_dim(D - 2); + } + let a_transpose_times_a = point.clone().transpose().matmul(point); + let all_dims = a_transpose_times_a.shape(); + debug_assert!(all_dims.num_dims() >= 2); + let other = identity_in_last_two(&a_transpose_times_a); + let in_orthogonal = a_transpose_times_a + .is_close(other, None, None) + .all_dim(D - 
1) + .all_dim(D - 2); + if IS_SPECIAL { + in_orthogonal + } else { + #[allow(unused_variables)] + let has_det_one = { todo!() }; + #[allow(unreachable_code)] + in_orthogonal.bool_and(has_det_one); + } + } + + fn is_tangent_at( + _point: Tensor, + _vector: Tensor, + ) -> Tensor { + todo!() + } +} + +impl MonoidManifold for OrthogonalGroup { + fn lie_mul(points0: Tensor, points1: Tensor) -> Tensor { + points0.matmul(points1) + } +} diff --git a/src/manifolds/sphere.rs b/src/manifolds/sphere.rs index 17bfd28..dacfd76 100644 --- a/src/manifolds/sphere.rs +++ b/src/manifolds/sphere.rs @@ -1,10 +1,11 @@ use crate::prelude::*; -/// Euclidean manifold - the simplest case where no projection is needed #[derive(Clone, Debug)] pub struct Sphere; impl Manifold for Sphere { + const RANK_PER_POINT: usize = 1; + fn new() -> Self { Self } @@ -13,18 +14,38 @@ impl Manifold for Sphere { "Sphere" } - fn project(_point: Tensor, vector: Tensor) -> Tensor - { - // Y/||y| - vector.clone()/(vector.clone().transpose().matmul(vector)).sqrt() + fn specific_name(s: &Shape) -> String { + let num_dims = s.num_dims(); + assert!( + num_dims > 0, + "There is at least one dimension where the manifold actually lives" + ); + let sphere_dim = *s + .dims + .last() + .expect("There is at least one dimension where the manifold actually lives"); + let (channel_dims, _) = s.dims.split_at(num_dims - 1); + if channel_dims.is_empty() { + format!("Sphere S^{} subset R^{sphere_dim}", sphere_dim - 1) + } else { + format!( + "{channel_dims:?} Channels worth of points in Sphere S^{} subset R^{sphere_dim}", + sphere_dim - 1 + ) + } } - fn retract( - _point: Tensor, - _direction: Tensor, - ) -> Tensor { - todo!("Implement retract for Sphere manifold") - + fn project(point: Tensor, vector: Tensor) -> Tensor { + // For sphere: project vector orthogonal to point + let dot_product = (point.clone() * vector.clone()).sum_dim(D - 1); + vector - point * dot_product + } + + fn retract(point: Tensor, direction: Tensor) -> Tensor { + // For sphere: normalize the result + let new_point = point + direction; + let norm = new_point.clone().powf_scalar(2.0).sum_dim(D - 1).sqrt(); + new_point / norm } fn inner( @@ -32,10 +53,225 @@ impl Manifold for Sphere { u: Tensor, v: Tensor, ) -> Tensor { - u.transpose().matmul(v) + u * v.sum_dim(D - 1) + } + + fn is_in_manifold(point: Tensor) -> Tensor { + let r_squared = point.powf_scalar(2.0).sum_dim(D - 1); + let one = r_squared.ones_like(); + r_squared.is_close(one, None, None) + } + + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let dot_product = (point * vector).sum_dim(D - 1); + let zero = dot_product.zeros_like(); + dot_product.is_close(zero, None, Some(1e-6)) + } + + fn proj(point: Tensor) -> Tensor { + let norm = point.clone().powf_scalar(2.0).sum_dim(D - 1).sqrt(); + point / norm + } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = *a_is.first().expect("The ambient R^n does exist"); + n > 0 + } +} + +#[cfg(test)] +mod test { + use crate::prelude::Manifold; + + use super::Sphere; + use burn::{ + backend::{Autodiff, NdArray}, + tensor::{Shape, Tensor}, + }; + + type TestBackend = Autodiff; + type TestTensor = Tensor; + type TestManyTensor = Tensor; + + const TOLERANCE: f32 = 1e-6; + + fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) { + let diff = (a.clone() - b.clone()).abs(); + let max_diff = diff.max().into_scalar(); + assert!( + max_diff < tol, + "Tensors differ by {}, tolerance: {}", + max_diff, + tol + ); + } + + fn create_test_matrix(rows: usize, 
values: Vec) -> TestTensor { + let device = Default::default(); + let data = &values[0..rows]; + Tensor::from_floats(data, &device) + } + + fn create_test_matrices( + data: [[[f32; ROWS]; CHANNEL1]; CHANNEL0], + ) -> TestManyTensor { + let device = Default::default(); + Tensor::from_floats(data, &device) + } + + #[test] + fn test_manifold_creation() { + let _manifold = >::new(); + assert_eq!(>::name(), "Sphere"); + assert_eq!( + >::specific_name(&burn::tensor::Shape { + dims: vec![5] + }), + "Sphere S^4 subset R^5" + ); + assert_eq!( + >::specific_name(&burn::tensor::Shape { + dims: vec![10, 30, 5] + }), + "[10, 30] Channels worth of points in Sphere S^4 subset R^5" + ); + } + + #[test] + fn test_projection_tangent_space() { + // Create a point on the Sphere manifold + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + // Create a direction vector + let direction = create_test_matrix(6, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]); + + let projected = + >::project(point.clone(), direction.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (point.clone() * projected.clone()).sum(); + let max_entry = product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent space: absoulte value of the dot product = {}", + max_entry + ); + } + + #[test] + fn test_many_projection_tangent_space() { + // Create many points on the Sphere manifold + let point_00 = [3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]; + let point_01 = [4.0 / 5.0, 0.0, 3.0 / 5.0, 0.0, 0.0, 0.0]; + let point_02 = [1.0 / 1.0, 0.0, 0.0, 0.0 / 1.0, 0.0, 0.0]; + let point_10 = [0.0 / 1.0, 0.0, 0.0, -1.0 / 1.0, 0.0, 0.0]; + let point_11 = [3.0 / 5.0, 0.0, 4.0 / 5.0, 0.0, 0.0, 0.0]; + let point_12 = [3.0 / 5.0, 0.0, 0.0, -4.0 / 5.0, 0.0, 0.0]; + let points = create_test_matrices::<6, 2, 3>([ + [point_00, point_01, point_02], + [point_10, point_11, point_12], + ]); + assert_eq!( + points.shape(), + Shape { + dims: vec![2, 3, 6] + } + ); + + // Create many direction vectors + let directions = TestManyTensor::random( + points.shape(), + burn::tensor::Distribution::Uniform(-1.0, 1.0), + &points.device(), + ); + + let projecteds = + >::project(points.clone(), directions.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (points.clone() * projecteds.clone()).sum_dim(2); + let max_entry = product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent space: absoulte value of the dot product = {}", + max_entry + ); } - fn is_in_manifold(_point: Tensor) -> bool { - true + #[test] + fn test_projection_preserves_tangent_vectors() { + // Create a point on the Sphere manifold + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()).into_scalar(), + "This is a point on the sphere by construction" + ); + + // Create a direction vector + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + assert!( + Sphere::is_tangent_at(point.clone(), direction.clone()).into_scalar(), + "This direction is orthogonal to point by construction" + ); + + let projected = + >::project(point.clone(), direction.clone()); + + // The projection should be orthogonal to the point + // i.e., point^T * projected should be 0 + let product = (point.clone() * projected.clone()).sum(); + let max_entry = 
product.abs().max().into_scalar(); + assert!( + max_entry < TOLERANCE, + "Projected direction not in tangent space: absoulte value of the dot product = {}", + max_entry + ); + + assert!( + Sphere::is_tangent_at(point.clone(), projected.clone()).into_scalar(), + "Projecting something already in the tangent space stays in the tangent space" + ); + assert_tensor_close(&projected, &direction, TOLERANCE); + } + + #[test] + fn test_retraction_preserves_sphere_property() { + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()).into_scalar(), + "This is a point on the sphere by construction" + ); + + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + let moved = Sphere::retract(point, direction); + + assert!(Sphere::is_in_manifold(moved).into_scalar()); + } + + #[test] + fn test_parallel_transport() { + let point = create_test_matrix(6, vec![3.0 / 5.0, 0.0, 0.0, 4.0 / 5.0, 0.0, 0.0]); + + assert!( + Sphere::is_in_manifold(point.clone()).into_scalar(), + "This is a point on the sphere by construction" + ); + + let direction = create_test_matrix(6, vec![4.0 / 5.0, 0.2, 0.3, -3.0 / 5.0, 0.5, 0.6]); + + let moved_point = Sphere::retract(point.clone(), direction.clone()); + let moved_vector = Sphere::parallel_transport(point, moved_point.clone(), direction); + + assert!(Sphere::is_in_manifold(moved_point.clone()).into_scalar()); + assert!(Sphere::is_tangent_at(moved_point, moved_vector).into_scalar()); } } diff --git a/src/manifolds/steifiel.rs b/src/manifolds/steifiel.rs index 3c5be4c..60c7cb7 100644 --- a/src/manifolds/steifiel.rs +++ b/src/manifolds/steifiel.rs @@ -1,4 +1,4 @@ -use crate::prelude::*; +use crate::{manifolds::utils::identity_in_last_two, prelude::*}; #[derive(Debug, Clone, Default)] pub struct SteifielsManifold { @@ -6,6 +6,8 @@ pub struct SteifielsManifold { } impl Manifold for SteifielsManifold { + const RANK_PER_POINT: usize = 2; + fn new() -> Self { SteifielsManifold { _backend: std::marker::PhantomData, @@ -17,7 +19,7 @@ impl Manifold for SteifielsManifold { } /// Project direction onto tangent space at point - /// For Stiefel manifold: P_X(Z) = Z - X(X^T Z + Z^T X)/2 + /// For Stiefel manifold: `P_X(Z) = Z - X(X^T Z + Z^T X)/2` fn project(point: Tensor, direction: Tensor) -> Tensor { let xtd = point.clone().transpose().matmul(direction.clone()); let dtx = direction.clone().transpose().matmul(point.clone()); @@ -25,21 +27,79 @@ impl Manifold for SteifielsManifold { direction - point.matmul(symmetric_part) } - fn retract( - point: Tensor, - direction: Tensor, - ) -> Tensor { - let s = point + direction; - gram_schmidt(&s) + fn retract(point: Tensor, direction: Tensor) -> Tensor { + debug_assert!(point.dims().len() >= Self::RANK_PER_POINT); + debug_assert!(direction.dims().len() >= Self::RANK_PER_POINT); + let mut s = point + direction; + if s.dims().len() > Self::RANK_PER_POINT { + // Gram_schmidt as written does so on the first two coordinates + // unlike Matrix multiplication and the rest of tensor operations + // which is assuming the last two coordinates + // and the first bunch being channels instead of vice versa + s = s.swap_dims(0, D - 2); + s = s.swap_dims(1, D - 1); + s = gram_schmidt(&s); + s = s.swap_dims(1, D - 1); + s = s.swap_dims(0, D - 2); + s + } else { + gram_schmidt(&s) + } } fn inner( - _point: Tensor< B, D>, + _point: Tensor, u: Tensor, v: Tensor, ) -> Tensor { // For Stiefel manifold, we use the standard Euclidean inner product - u * v 
+ (u * v).sum_dim(D - 1).sum_dim(D - 2) + } + + fn is_tangent_at( + point: Tensor, + vector: Tensor, + ) -> Tensor { + let xtv = point.clone().transpose().matmul(vector.clone()); + let vtx = vector.clone().transpose().matmul(point.clone()); + let skew = xtv + vtx.transpose(); + let max_skew = skew.clone().abs().max_dim(D - 1).max_dim(D - 2); + max_skew.lower_elem(1e-6) + } + + fn proj(mut point: Tensor) -> Tensor { + debug_assert!(point.dims().len() >= Self::RANK_PER_POINT); + if point.dims().len() > Self::RANK_PER_POINT { + // Gram_schmidt as written does so on the first two coordinates + // unlike Matrix multiplication and the rest of tensor operations + // which is assuming the last two coordinates + // and the first bunch being channels instead of vice versa + point = point.swap_dims(0, D - 2); + point = point.swap_dims(1, D - 1); + point = gram_schmidt(&point); + point = point.swap_dims(1, D - 1); + point = point.swap_dims(0, D - 2); + point + } else { + gram_schmidt(&point) + } + } + + fn is_in_manifold(point: Tensor) -> Tensor { + let a_transpose_times_a = point.clone().transpose().matmul(point); + let all_dims = a_transpose_times_a.shape(); + debug_assert!(all_dims.num_dims() >= 2); + let other = identity_in_last_two(&a_transpose_times_a); + a_transpose_times_a + .is_close(other, None, None) + .all_dim(D - 1) + .all_dim(D - 2) + } + + fn acceptable_dims(a_is: &[usize]) -> bool { + let n = a_is[0]; + let k = a_is[1]; + n > 0 && k > 0 && k <= n } } @@ -70,6 +130,9 @@ fn gram_schmidt(v: &Tensor) -> Tensor { #[cfg(test)] mod test { + use crate::manifolds::utils::test::{assert_matrix_close, create_test_matrix}; + use crate::optimizers::LessSimpleOptimizer; + use super::*; use burn::{ backend::{Autodiff, NdArray}, @@ -77,106 +140,9 @@ mod test { }; type TestBackend = Autodiff; - type TestTensor = Tensor; const TOLERANCE: f32 = 1e-6; - fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) { - let diff = (a.clone() - b.clone()).abs(); - let max_diff = diff.max().into_scalar(); - assert!( - max_diff < tol, - "Tensors differ by {}, tolerance: {}", - max_diff, - tol - ); - } - - fn create_test_matrix(rows: usize, cols: usize, values: Vec) -> TestTensor { - let device = Default::default(); - // Reshape the flat vector into a 2D array - let mut data = Vec::with_capacity(rows); - for chunk in values.chunks(cols) { - data.push(chunk.to_vec()); - } - - // Create tensor from nested arrays - match (rows, cols) { - (3, 2) => { - if data.len() >= 3 && data[0].len() >= 2 && data[1].len() >= 2 && data[2].len() >= 2 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1]], - [data[1][0], data[1][1]], - [data[2][0], data[2][1]], - ], - &device, - ) - } else { - panic!("Invalid 3x2 matrix data"); - } - } - (3, 1) => { - if data.len() >= 3 - && !data[0].is_empty() - && !data[1].is_empty() - && !data[2].is_empty() - { - Tensor::from_floats([[data[0][0]], [data[1][0]], [data[2][0]]], &device) - } else { - panic!("Invalid 3x1 matrix data"); - } - } - (3, 3) => { - if data.len() >= 3 && data[0].len() >= 3 && data[1].len() >= 3 && data[2].len() >= 3 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1], data[0][2]], - [data[1][0], data[1][1], data[1][2]], - [data[2][0], data[2][1], data[2][2]], - ], - &device, - ) - } else { - panic!("Invalid 3x3 matrix data"); - } - } - (4, 2) => { - if data.len() >= 4 - && data[0].len() >= 2 - && data[1].len() >= 2 - && data[2].len() >= 2 - && data[3].len() >= 2 - { - Tensor::from_floats( - [ - [data[0][0], data[0][1]], - [data[1][0], data[1][1]], - [data[2][0], 
data[2][1]], - [data[3][0], data[3][1]], - ], - &device, - ) - } else { - panic!("Invalid 4x2 matrix data"); - } - } - (2, 2) => { - if data.len() >= 2 && data[0].len() >= 2 && data[1].len() >= 2 { - Tensor::from_floats( - [[data[0][0], data[0][1]], [data[1][0], data[1][1]]], - &device, - ) - } else { - panic!("Invalid 2x2 matrix data"); - } - } - _ => panic!("Unsupported matrix dimensions: {}x{}", rows, cols), - } - } - #[test] fn test_manifold_creation() { let _manifold = SteifielsManifold::::new(); @@ -186,7 +152,7 @@ mod test { #[test] fn test_gram_schmidt_orthogonalization() { // Test with a simple 3x2 matrix - let input = create_test_matrix(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]); + let input = create_test_matrix::(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]); let result = gram_schmidt(&input); @@ -232,7 +198,7 @@ mod test { #[test] fn test_gram_schmidt_single_column() { // Test with a single column vector - let input = create_test_matrix(3, 1, vec![3.0, 4.0, 0.0]); + let input = create_test_matrix::(3, 1, vec![3.0, 4.0, 0.0]); let result = gram_schmidt(&input); // Should be normalized to unit length @@ -249,8 +215,8 @@ mod test { ); // Should be proportional to original vector - let expected = create_test_matrix(3, 1, vec![0.6, 0.8, 0.0]); - assert_tensor_close(&result, &expected, TOLERANCE); + let expected = create_test_matrix::(3, 1, vec![0.6, 0.8, 0.0]); + assert_matrix_close(&result, &expected, TOLERANCE); } #[test] @@ -286,7 +252,7 @@ mod test { // Project the tangent vector again let projected = SteifielsManifold::::project(point.clone(), tangent.clone()); // Should be unchanged (idempotent) - assert_tensor_close(&projected, &tangent, 1e-6); + assert_matrix_close(&projected, &tangent, 1e-6); // Check the tangent space property: X^T V + V^T X = 0 let xtv = point.clone().transpose().matmul(tangent.clone()); let vtx = tangent.clone().transpose().matmul(point.clone()); @@ -297,6 +263,10 @@ mod test { "Tangent space property violated: max skew = {}", max_skew ); + assert!( + SteifielsManifold::is_tangent_at(point, tangent).into_scalar(), + "Tangent space property violated: max skew unknown" + ) } #[test] @@ -309,7 +279,7 @@ mod test { let step = 0.1; let retracted = - SteifielsManifold::::retract(point.clone(), direction.clone()*step); + SteifielsManifold::::retract(point.clone(), direction.clone() * step); // Check that the result has orthonormal columns let q1 = retracted.clone().slice([0..3, 0..1]); @@ -347,15 +317,23 @@ mod test { "Second column not normalized after retraction: norm = {}", norm2 ); + + assert!(SteifielsManifold::::is_in_manifold(retracted) + .all() + .into_scalar()); } #[test] fn test_gram_schmidt_identity_matrix() { // Identity matrix should remain unchanged - let identity = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); + let identity = create_test_matrix::( + 3, + 3, + vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + ); let result = gram_schmidt(&identity); - assert_tensor_close(&result, &identity, TOLERANCE); + assert_matrix_close(&result, &identity, TOLERANCE); } #[test] @@ -374,7 +352,7 @@ mod test { let gram_matrix = point.clone().transpose().matmul(point.clone()); let identity = create_test_matrix(2, 2, vec![1.0, 0.0, 0.0, 1.0]); - assert_tensor_close(&gram_matrix, &identity, TOLERANCE); + assert_matrix_close(&gram_matrix, &identity, TOLERANCE); // Test projection and retraction preserve this property let direction = create_test_matrix(4, 2, vec![0.1, 0.0, 0.0, 0.1, 0.2, 0.3, -0.1, 0.2]); @@ -383,7 +361,7 @@ mod test { 
let retracted = SteifielsManifold::::retract(point.clone(), projected * 0.1); let retracted_gram = retracted.clone().transpose().matmul(retracted.clone()); - assert_tensor_close(&retracted_gram, &identity, TOLERANCE); + assert_matrix_close(&retracted_gram, &identity, TOLERANCE); } #[test] @@ -406,7 +384,9 @@ mod test { .matmul(x.clone()) .sum(); let grads = loss.backward(); - let x_grad = x.grad(&grads).unwrap(); + let x_grad = x + .grad(&grads) + .expect("The gradients do exist we just did loss.backwards()"); // Convert gradient to autodiff backend and ensure independent tensor let x_grad_data = x_grad.to_data(); let x_grad_ad = Tensor::::from_data(x_grad_data, &x.device()); @@ -420,9 +400,105 @@ mod test { } #[test] - fn test_simple_optimizer_step() { + fn test_optimiser_remove() { let optimiser = ManifoldRGD::, TestBackend>::default(); + let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]); + + let mut x = Tensor::::random( + [3, 3], + burn::tensor::Distribution::Normal(1., 1.), + &a.device(), + ) + .require_grad(); + for _i in 0..100 { + let loss = x + .clone() + .transpose() + .matmul(a.clone()) + .matmul(x.clone()) + .sum(); + let mut grads = loss.backward(); + let x_grad = x + .grad_remove(&mut grads) + .expect("The gradients do exist we just did loss.backwards()"); + // Convert gradient to autodiff backend and ensure independent tensor + let x_grad_data = x_grad.to_data(); + let x_grad_ad = Tensor::::from_data(x_grad_data, &x.device()); + // Clone x to ensure independent tensor for optimizer + let x_clone = x.clone(); + let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None); + x = new_x.detach().require_grad(); + println!("Loss: {}", loss); + } + println!("Optimised tensor: {}", x); + } + + #[test] + fn test_optimiser_many() { + let optimiser = ManifoldRGD::, TestBackend>::default(); + + let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]); + + let mut x = Tensor::::random( + [3, 3], + burn::tensor::Distribution::Normal(1., 1.), + &a.device(), + ) + .require_grad(); + + fn grad_fn( + x: Tensor, 2>, + a: Tensor, 2>, + ) -> Tensor, 2> { + let loss = x.clone().transpose().matmul(a).matmul(x.clone()).sum(); + let mut grads = loss.backward(); + let x_grad = x + .grad_remove(&mut grads) + .expect("The gradients do exist we just did loss.backwards()"); + // Convert gradient to autodiff backend and ensure independent tensor + let x_grad_ad = Tensor::::from_data(x_grad.to_data(), &x.device()); + x_grad_ad + } + + let mut state = None; + let x_original: Tensor = + Tensor::::from_data(x.to_data(), &Default::default()); + let a_original: Tensor = + Tensor::::from_data(a.to_data(), &Default::default()); + let unoptimised_loss = x_original + .clone() + .transpose() + .matmul(a_original.clone()) + .matmul(x_original.clone()) + .sum() + .into_scalar(); + println!( + "Unoptimised tensor: {} with loss {}", + x_original, unoptimised_loss + ); + (x, state) = optimiser.many_steps(|_| 0.1, 100, |x| grad_fn(x, a.clone()), x, state); + assert!(state.is_none()); + let x_optimised: Tensor = + Tensor::::from_data(x.to_data(), &Default::default()); + let optimised_loss = x_optimised + .clone() + .transpose() + .matmul(a_original) + .matmul(x_optimised.clone()) + .sum() + .into_scalar(); + println!( + "Optimised tensor: {} with loss {}", + x_optimised, optimised_loss + ); + assert!(optimised_loss <= unoptimised_loss, + "The optimimisation should have lowered the loss function. 
It was {unoptimised_loss} before and {optimised_loss} after"); + } + + #[test] + fn test_simple_optimizer_step() { + let optimiser = ManifoldRGD::, TestBackend>::default(); // Create simple test tensors let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]); diff --git a/src/manifolds/utils.rs b/src/manifolds/utils.rs new file mode 100644 index 0000000..2645eee --- /dev/null +++ b/src/manifolds/utils.rs @@ -0,0 +1,338 @@ +use burn::{prelude::Backend, tensor::Tensor}; + +/// Given a tensor of shape `a_1 ... a_k x N x N` +/// create a tensor of the same shape +/// whose `i_1...i_k,m,n` entry is `1` if `m==n` and `0` otherwise +pub(crate) fn identity_in_last_two( + example: &Tensor, +) -> Tensor { + let shape: [usize; D] = example.shape().dims(); + debug_assert!(D >= 2); + debug_assert_eq!(shape[D - 1], shape[D - 2]); + let n = shape[D - 1]; + let identity = Tensor::eye(n, &example.device()); + // Broadcasting is right aligned + // so this will act like identity in shape (1,...1,N,N) + // and then broadcast those 1's to a_1 ... a_k + // even though that is not stated in the docs of `expand` + // More honestly phrased identity of rank 2 is a function NxN -> R + // and the return value of type a1xa2...xNxN -> R is being done as + // the precomposition with the map of sets a1x...akxNxN -> 1x...1xNxN -> NxN which is the + // terminal map and identity maps in each factor and then monoidal unit laws. + // Relying on this implicit extraneous structure is the reason do not have to do this manually like in + // the below implementation of `diag_i`. + identity.expand(example.shape()) +} + +/// Given a tensor of shape `a_1 ... a_k x N x N` +/// create a tensor of the same shape +/// whose `i_1...i_k,m,n` entry is `f(m)` if `m==n` and `0` otherwise +#[allow(dead_code)] +pub(crate) fn diag_i( + example: &Tensor, + diag_fun: impl Fn(usize) -> f32, +) -> Tensor { + let shape: [usize; D] = example.shape().dims(); + debug_assert!(D >= 2); + debug_assert_eq!(shape[D - 1], shape[D - 2]); + let n = shape[D - 1]; + let mut other = example.zeros_like(); + let mut ones_shape = [1usize; D]; + ones_shape[..(D - 2)].copy_from_slice(&shape[..(D - 2)]); + let ones_patch = Tensor::::ones(ones_shape, &example.device()); + for diag in 0..n { + let ranges: [_; D] = std::array::from_fn(|dim| { + if dim < D - 2 { + 0..shape[dim] + } else { + diag..diag + 1 + } + }); + other = other.slice_assign(ranges, ones_patch.clone().mul_scalar(diag_fun(diag))); + } + other +} + +/// Given a tensor of shape `l_1 .. l_a .. l_b .. l_D` +/// The rank is dropped to `D-2` +/// but because of the lack of const operations +/// we leave those dimensions there but with `1`s +/// in the shape instead. +/// `l_1 .. 1 .. 1 .. l_D` +pub(crate) fn ein_sum( + mut t: Tensor, + a: usize, + b: usize, +) -> Tensor { + debug_assert!(a < D); + debug_assert!(b < D); + debug_assert_ne!(a, b); + let t_shape: [usize; D] = t.shape().dims(); + debug_assert_eq!(t_shape[a], t_shape[b]); + if a != D - 2 || b != D - 1 { + t = t.swap_dims(a, D - 2); + t = t.swap_dims(b, D - 1); + } + let identity_last_two = identity_in_last_two(&t); + t = t.mul(identity_last_two); + t = t.sum_dim(D - 1).sum_dim(D - 2); + if a != D - 2 || b != D - 1 { + t = t.swap_dims(b, D - 1); + t = t.swap_dims(a, D - 2); + } + t +} + +/// Given a tensor of shape `a_1 ... a_k x N x N` +/// create a tensor of shape `a_1 ... a_k x 1 x 1` +/// whose `i_1...i_k,0,0` entry is the trace +/// of the `N x N` matrix with those previous `k` +/// fixed but the last two remaining free. 
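+///
+/// A small shape-level sketch (`NdArray` is only an illustrative backend):
+///
+/// ```rust,ignore
+/// let m: Tensor<NdArray, 4> = Tensor::ones([5, 7, 3, 3], &Default::default());
+/// let t = trace(m);
+/// // Rank is kept; the matrix dimensions collapse to 1, and every entry here is 3.0.
+/// assert_eq!(t.shape().dims(), [5, 7, 1, 1]);
+/// ```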
+#[allow(dead_code)] +pub(crate) fn trace(t: Tensor) -> Tensor { + // burn::tensor::linalg looks like a module according to the docs, but + // doing `burn::tensor::linalg::trace` does not work + // Also it reduces the rank by 1 and do not have the const generics + // to do {D-1}. The docs there also are contradictory with D0 and D-1 + // because of this inability to do {D-1}. Creates things that compile + // but really should not, but is not visible about it. + // Easier just to use the 1's as here. + ein_sum(t, D - 2, D - 1) +} + +#[cfg(test)] +pub(crate) mod test { + + use burn::{backend::NdArray, prelude::Backend, tensor::Tensor}; + + use crate::manifolds::utils::{diag_i, ein_sum, identity_in_last_two}; + + pub(crate) fn assert_matrix_close( + a: &Tensor, + b: &Tensor, + tol: f32, + ) where + TestBackend: Backend, + ::FloatElem: PartialOrd, + { + let diff = (a.clone() - b.clone()).abs(); + let max_diff = diff.max().into_scalar(); + assert!( + max_diff < tol, + "Tensors differ by {}, tolerance: {}", + max_diff, + tol + ); + } + + pub(crate) fn create_test_matrix( + rows: usize, + cols: usize, + values: Vec, + ) -> Tensor { + debug_assert_ne!(rows, 0); + debug_assert_ne!(cols, 0); + if rows < cols { + return create_test_matrix(cols, rows, values).transpose(); + } + let device = Default::default(); + // Reshape the flat vector into a 2D array + let mut data = Vec::with_capacity(rows); + for chunk in values.chunks(cols) { + data.push(chunk.to_vec()); + } + + // Create tensor from nested arrays + match (rows, cols) { + (3, 2) => { + if data.len() >= 3 && data[0].len() >= 2 && data[1].len() >= 2 && data[2].len() >= 2 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1]], + [data[1][0], data[1][1]], + [data[2][0], data[2][1]], + ], + &device, + ) + } else { + panic!("Invalid 3x2 matrix data"); + } + } + (3, 1) => { + if data.len() >= 3 + && !data[0].is_empty() + && !data[1].is_empty() + && !data[2].is_empty() + { + Tensor::from_floats([[data[0][0]], [data[1][0]], [data[2][0]]], &device) + } else { + panic!("Invalid 3x1 matrix data"); + } + } + (3, 3) => { + if data.len() >= 3 && data[0].len() >= 3 && data[1].len() >= 3 && data[2].len() >= 3 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1], data[0][2]], + [data[1][0], data[1][1], data[1][2]], + [data[2][0], data[2][1], data[2][2]], + ], + &device, + ) + } else { + panic!("Invalid 3x3 matrix data"); + } + } + (4, 2) => { + if data.len() >= 4 + && data[0].len() >= 2 + && data[1].len() >= 2 + && data[2].len() >= 2 + && data[3].len() >= 2 + { + Tensor::from_floats( + [ + [data[0][0], data[0][1]], + [data[1][0], data[1][1]], + [data[2][0], data[2][1]], + [data[3][0], data[3][1]], + ], + &device, + ) + } else { + panic!("Invalid 4x2 matrix data"); + } + } + (2, 2) => { + if data.len() >= 2 && data[0].len() >= 2 && data[1].len() >= 2 { + Tensor::from_floats( + [[data[0][0], data[0][1]], [data[1][0], data[1][1]]], + &device, + ) + } else { + panic!("Invalid 2x2 matrix data"); + } + } + (2, 1) => { + if data.len() >= 2 && !data[0].is_empty() && !data[1].is_empty() { + Tensor::from_floats([[data[0][0]], [data[1][0]]], &device) + } else { + panic!("Invalid 2x1 matrix data"); + } + } + (1, 1) => { + if data.len() >= 1 && !data[0].is_empty() { + Tensor::from_floats([[data[0][0]]], &device) + } else { + panic!("Invalid 1x1 matrix data"); + } + } + _ => panic!("Unsupported matrix dimensions: {}x{}", rows, cols), + } + } + + #[test] + fn small_einsum() { + let mat = create_test_matrix::( + 3, + 3, + vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, 
-1.0], + ); + let ein_summed = ein_sum(mat.clone(), 0, 1); + assert_eq!(ein_summed.shape().dims(), [1, 1]); + let scalar = ein_summed.into_scalar(); + assert!((scalar - 9.0).abs() <= 1e-6, "{}", scalar); + } + + #[test] + fn identity_test() { + { + let mat = create_test_matrix::( + 3, + 3, + vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0], + ); + let mat = mat.expand([3, 3]); + let identity_mat = identity_in_last_two(&mat); + let expected = Tensor::eye(3, &identity_mat.device()); + assert_matrix_close(&identity_mat, &expected, 1e-6); + } + let mat = create_test_matrix::( + 3, + 3, + vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0], + ); + let expanded_shape = [3, 3, 3, 3, 3]; + let mat = mat.expand(expanded_shape); + let identity_mat = identity_in_last_two(&mat); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + let expected = Tensor::eye(3, &slice.device()); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + let mat = create_test_matrix::( + 3, + 3, + vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0], + ); + let expanded_shape = [29, 483, 2, 3, 3]; + let mat = mat.expand(expanded_shape); + let identity_mat = identity_in_last_two(&mat); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + let expected = Tensor::eye(3, &slice.device()); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + } + + #[test] + fn diag_test() { + let diag_entries = [2.0, 7.0, 9.0]; + let expected = + create_test_matrix::(3, 3, vec![2.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 9.0]); + let expanded_shape = [3, 3, 3, 3, 3]; + let mat = expected.clone().expand(expanded_shape); + let identity_mat = diag_i(&mat, |i| diag_entries[i]); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + let expanded_shape = [10, 9, 20, 3, 3]; + let mat = expected.clone().expand(expanded_shape); + let identity_mat = diag_i(&mat, |i| diag_entries[i]); + for idx in 0..expanded_shape[0] { + for jdx in 0..expanded_shape[1] { + for kdx in 0..expanded_shape[2] { + let slice = identity_mat + .clone() + .slice([idx..idx + 1, jdx..jdx + 1, kdx..kdx + 1, 0..3, 0..3]) + .reshape([3, 3]); + assert_matrix_close(&slice, &expected, 1e-6); + } + } + } + } +} diff --git a/src/optimizers.rs b/src/optimizers.rs index 4bcf6b4..5634904 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -2,17 +2,20 @@ //! //! This module provides Riemannian optimization algorithms that work on manifolds, //! extending classical optimization methods to handle geometric constraints. 
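// Illustrative sketch (not part of the patch above): the basic Riemannian gradient step
// that the optimizers in this module build on, written against the `Manifold` trait from
// the crate prelude. Only `Manifold::project` and `Manifold::retract`, which appear in
// this diff, are assumed; the manifold `M`, the learning rate, and the rank `D` are
// placeholders, and `Backend`/`Tensor` are expected to come from `crate::prelude::*`.
fn riemannian_step<B: Backend, M: Manifold<B>, const D: usize>(
    point: Tensor<B, D>,
    euclidean_grad: Tensor<B, D>,
    lr: f64,
) -> Tensor<B, D> {
    // Project the Euclidean gradient onto the tangent space at `point`,
    // then retract the scaled direction back onto the manifold.
    let rgrad = M::project(point.clone(), euclidean_grad);
    M::retract(point, rgrad.mul_scalar(-lr))
}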
+use crate::prelude::*; use burn::module::AutodiffModule; use burn::optim::{adaptor::OptimizerAdaptor, LrDecayState, SimpleOptimizer}; use burn::record::Record; use burn::tensor::backend::AutodiffBackend; use burn::LearningRate; use std::marker::PhantomData; -use crate::prelude::*; +pub mod hessian_optimizer; +pub mod many_steps; pub mod multiple; +pub use many_steps::LessSimpleOptimizer; #[derive(Debug)] -pub struct ManifoldRGDConfig { +pub struct ManifoldRGDConfig, B: Backend> { _manifold: PhantomData, _backend: PhantomData, } @@ -74,10 +77,17 @@ where } fn to_device( - _state: Self::State, - _device: &::Device, + state: Self::State, + device: &::Device, ) -> Self::State { - _state + const DECAY_STATE_TO_DEVICE: bool = false; + if DECAY_STATE_TO_DEVICE { + ManifoldRGDState { + lr_decay: state.lr_decay.to_device(device), + } + } else { + state + } } } @@ -86,6 +96,7 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn init>( &self, ) -> OptimizerAdaptor, Mod, Back> @@ -117,7 +128,7 @@ where /// .with_amsgrad(true); /// ``` #[derive(Debug, Clone)] -pub struct RiemannianAdamConfig { +pub struct RiemannianAdamConfig, B: Backend> { pub lr: f64, pub beta1: f64, pub beta2: f64, @@ -154,40 +165,48 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn new() -> Self { Self::default() } + #[must_use] pub fn with_lr(mut self, lr: f64) -> Self { self.lr = lr; self } + #[must_use] pub fn with_beta1(mut self, beta1: f64) -> Self { self.beta1 = beta1; self } + #[must_use] pub fn with_beta2(mut self, beta2: f64) -> Self { self.beta2 = beta2; self } + #[must_use] pub fn with_eps(mut self, eps: f64) -> Self { self.eps = eps; self } + #[must_use] pub fn with_weight_decay(mut self, weight_decay: f64) -> Self { self.weight_decay = weight_decay; self } + #[must_use] pub fn with_amsgrad(mut self, amsgrad: bool) -> Self { self.amsgrad = amsgrad; self } + #[must_use] pub fn with_stabilize(mut self, stabilize: Option) -> Self { self.stabilize = stabilize; self @@ -205,6 +224,7 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn new(config: RiemannianAdamConfig) -> Self { Self { config } } @@ -268,12 +288,15 @@ where state.exp_avg = state.exp_avg.clone() * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1); - let inner_product = M::inner(tensor.clone(), rgrad.clone(), rgrad.clone()); - state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2 + inner_product * (1.0 - self.config.beta2); + let inner_product = M::inner::(tensor.clone(), rgrad.clone(), rgrad.clone()); + state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2 + + inner_product * (1.0 - self.config.beta2); // Compute denominator let denom = if self.config.amsgrad { - let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().unwrap(); + let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().expect( + "On an initial None state, having config.amsgrad be True makes this maximum field set to 0. 
\ + If there was an input state then it will be present because of earlier steps"); let new_max = Tensor::max_pair(max_exp_avg_sq.clone(), state.exp_avg_sq.clone()); state.max_exp_avg_sq = Some(new_max.clone()); new_max.sqrt() + self.config.eps @@ -282,7 +305,9 @@ where }; // Bias correction + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] let bias_correction1 = 1.0 - self.config.beta1.powi(state.step as i32); + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] let bias_correction2 = 1.0 - self.config.beta2.powi(state.step as i32); let step_size = learning_rate * bias_correction2.sqrt() / bias_correction1; @@ -319,6 +344,7 @@ where M: Manifold, B: Backend, { + #[must_use] pub fn init>( &self, ) -> OptimizerAdaptor, Mod, Back> @@ -419,10 +445,11 @@ mod tests { // Check that AMSGrad state is initialized assert!(state.is_some()); - let state = state.unwrap(); + let state = + state.expect("RiemannianAdam optimizer always gives back an initialized state on step"); assert!( state.max_exp_avg_sq.is_some(), - "AMSGrad should initialize max_exp_avg_sq" + "AMSGrad should initialize max_exp_avg_sq. See the explanation around the compute denominator part of step" ); } @@ -461,13 +488,16 @@ mod tests { // First step let (tensor1, state1) = optimizer.step(1.0, tensor, grad.clone(), None); assert!(state1.is_some()); - let state1 = state1.unwrap(); + let state1 = state1.expect( + "RiemannianAdam optimizer always gives back an initialized state on step even with None initial state"); assert_eq!(state1.step, 1); // Second step with state let (_, state2) = optimizer.step(1.0, tensor1, grad, Some(state1)); assert!(state2.is_some()); - let state2 = state2.unwrap(); + let state2 = state2.expect( + "There was an input state so RiemannianAdam optimizer's step modifies that and returns it" + ); assert_eq!(state2.step, 2); } } diff --git a/src/optimizers/hessian_optimizer.rs b/src/optimizers/hessian_optimizer.rs new file mode 100644 index 0000000..6c52ad8 --- /dev/null +++ b/src/optimizers/hessian_optimizer.rs @@ -0,0 +1,137 @@ +use burn::{ + optim::SimpleOptimizer, prelude::Backend, record::Record, tensor::Tensor, LearningRate, +}; + +/// TODO document and construct an implementation +pub trait SimpleHessianOptimizer: Send + Sync + Clone + SimpleOptimizer +where + B: Backend, +{ + /// The state of the optimizer. It also implements [record](Record), so that it can be saved. + type StateWithHessian: Record + + Clone + + From> + + Into> + + 'static; + + /// The optimizer step is performed for one tensor at a time with its gradient, hessian and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. + fn step_with_hessian( + &self, + lr: LearningRate, + tensor: Tensor, + grad: Tensor, + hessian: Tensor, + state: Option>, + ) -> (Tensor, Option>); + + /// The optimizer step is performed for one tensor at a time with its gradient, hessian and state. + /// + /// Note that the state is passed as parameter, so implementations don't have to handle + /// the saving and loading of recorded states. 
+    fn step<const D: usize>(
+        &self,
+        lr: LearningRate,
+        tensor: Tensor<B, D>,
+        grad: Tensor<B, D>,
+        hessian: Option<Tensor<B, D>>,
+        state: Option<Self::StateWithHessian<D>>,
+    ) -> (Tensor<B, D>, Option<Self::StateWithHessian<D>>) {
+        if let Some(hessian) = hessian {
+            self.step_with_hessian(lr, tensor, grad, hessian, state)
+        } else {
+            let (new_pt, new_state) =
+                SimpleOptimizer::step(self, lr, tensor, grad, state.map(Into::into));
+            (new_pt, new_state.map(Into::into))
+        }
+    }
+
+    /// Change the device of the state.
+    ///
+    /// This function will be called accordingly to have the state on the same device as the
+    /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
+    fn to_device<const D: usize>(
+        state: Self::StateWithHessian<D>,
+        device: &B::Device,
+    ) -> Self::StateWithHessian<D>;
+}
+
+/// A `SimpleOptimizer` also works as a `SimpleHessianOptimizer` by ignoring the Hessian information
+impl<B: Backend, T: SimpleOptimizer<B>> SimpleHessianOptimizer<B> for T {
+    type StateWithHessian<const D: usize> = <Self as SimpleOptimizer<B>>::State<D>;
+
+    fn step_with_hessian<const D: usize>(
+        &self,
+        lr: LearningRate,
+        tensor: Tensor<B, D>,
+        grad: Tensor<B, D>,
+        _hessian: Tensor<B, D>,
+        state: Option<Self::StateWithHessian<D>>,
+    ) -> (Tensor<B, D>, Option<Self::StateWithHessian<D>>) {
+        self.step(lr, tensor, grad, state)
+    }
+
+    fn to_device<const D: usize>(
+        state: Self::StateWithHessian<D>,
+        device: &<B as Backend>::Device,
+    ) -> Self::StateWithHessian<D> {
+        <Self as SimpleOptimizer<B>>::to_device(state, device)
+    }
+}
+
+/// An optimizer that allows for many steps with a given learning schedule
+/// and a way of evaluating the gradient and hessian functions on arbitrary
+/// points. This way we can step using `SimpleHessianOptimizer::step` with
+/// that gradient and hessian several times.
+pub trait LessSimpleHessianOptimizer<B: Backend>: SimpleHessianOptimizer<B> {
+    fn many_steps<const D: usize>(
+        &self,
+        lr_function: impl FnMut(usize) -> LearningRate,
+        num_steps: usize,
+        grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
+        hessian_function: impl FnMut(Tensor<B, D>) -> Option<Tensor<B, D>>,
+        tensor: Tensor<B, D>,
+        state: Option<Self::StateWithHessian<D>>,
+    ) -> (Tensor<B, D>, Option<Self::StateWithHessian<D>>);
+}
+
+/// The implementation of `LessSimpleHessianOptimizer` is completely determined
+/// by how `SimpleHessianOptimizer` has been implemented because we are
+/// just taking gradients using the input `grad_function` and hessians with `hessian_function`
+/// and stepping with `SimpleHessianOptimizer::step`
+impl<B: Backend, T: SimpleHessianOptimizer<B>> LessSimpleHessianOptimizer<B> for T {
+    #[inline]
+    fn many_steps<const D: usize>(
+        &self,
+        mut lr_function: impl FnMut(usize) -> LearningRate,
+        num_steps: usize,
+        mut grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
+        mut hessian_function: impl FnMut(Tensor<B, D>) -> Option<Tensor<B, D>>,
+        mut tensor: Tensor<B, D>,
+        mut state: Option<Self::StateWithHessian<D>>,
+    ) -> (Tensor<B, D>, Option<Self::StateWithHessian<D>>) {
+        // Perform optimization steps
+        for step in 0..num_steps {
+            // Compute gradient at tensor
+            let cur_grad = grad_function(tensor.clone());
+            // Compute hessian at tensor
+            let cur_hessian = hessian_function(tensor.clone());
+            // The current learning rate for this step
+            let cur_lr = lr_function(step);
+            // Perform optimizer step
+            let (new_x, new_state) = SimpleHessianOptimizer::step(
+                self,
+                cur_lr,
+                tensor.clone(),
+                cur_grad,
+                cur_hessian,
+                state,
+            );
+            tensor = new_x.detach().require_grad();
+            state = new_state;
+        }
+        (tensor, state)
+    }
+}
diff --git a/src/optimizers/many_steps.rs b/src/optimizers/many_steps.rs
new file mode 100644
index 0000000..2c53c52
--- /dev/null
+++ b/src/optimizers/many_steps.rs
@@ -0,0 +1,50 @@
+//! An optimizer that allows for many steps with a given learning schedule
+//! and a way of evaluating the gradient function on arbitrary
+//! points. This way we can step using `SimpleOptimizer::step` with that gradient
+//! several times.
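// Illustrative sketch (not part of the patch): driving the blanket `LessSimpleOptimizer`
// implementation defined below. Any `SimpleOptimizer` gains `many_steps`; here it is fed a
// constant learning-rate schedule and a hand-written gradient closure for f(x) = ||x||^2 / 2,
// whose Euclidean gradient is x itself. The concrete optimizer is left generic because its
// construction depends on which optimizer is chosen; `Backend` and `Tensor` are assumed to
// come from `crate::prelude::*` as in the module itself.
fn run_many_steps<B: Backend, O: LessSimpleOptimizer<B>>(
    optimizer: &O,
    start: Tensor<B, 2>,
) -> Tensor<B, 2> {
    let grad_of_half_norm_squared = |x: Tensor<B, 2>| x;
    let constant_lr = |_step: usize| 0.1;
    // Take 100 steps from `start`, letting `many_steps` thread the optimizer state through.
    let (end, _state) =
        optimizer.many_steps(constant_lr, 100, grad_of_half_norm_squared, start, None);
    end
}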
+use crate::prelude::*;
+use burn::{optim::SimpleOptimizer, LearningRate};
+
+/// An optimizer that allows for many steps with a given learning schedule
+/// and a way of evaluating the gradient function on arbitrary
+/// points. This way we can step using `SimpleOptimizer::step` with that gradient
+/// several times.
+pub trait LessSimpleOptimizer<B: Backend>: SimpleOptimizer<B> {
+    fn many_steps<const D: usize>(
+        &self,
+        lr_function: impl FnMut(usize) -> LearningRate,
+        num_steps: usize,
+        grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
+        tensor: Tensor<B, D>,
+        state: Option<Self::State<D>>,
+    ) -> (Tensor<B, D>, Option<Self::State<D>>);
+}
+
+/// The implementation of `LessSimpleOptimizer` is completely determined
+/// by how `SimpleOptimizer` has been implemented because we are
+/// just taking gradients using the input `grad_function` and stepping with
+/// `SimpleOptimizer::step`
+impl<B: Backend, T: SimpleOptimizer<B>> LessSimpleOptimizer<B> for T {
+    #[inline]
+    fn many_steps<const D: usize>(
+        &self,
+        mut lr_function: impl FnMut(usize) -> LearningRate,
+        num_steps: usize,
+        mut grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
+        mut tensor: Tensor<B, D>,
+        mut state: Option<Self::State<D>>,
+    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
+        // Perform optimization steps
+        for step in 0..num_steps {
+            // Compute gradient at tensor
+            let cur_grad = grad_function(tensor.clone());
+            // The current learning rate for this step
+            let cur_lr = lr_function(step);
+            // Perform optimizer step
+            let (new_x, new_state) = self.step(cur_lr, tensor.clone(), cur_grad, state);
+            tensor = new_x.detach().require_grad();
+            state = new_state;
+        }
+        (tensor, state)
+    }
+}
diff --git a/src/optimizers/multiple.rs b/src/optimizers/multiple.rs
index 1c6c887..7932d31 100644
--- a/src/optimizers/multiple.rs
+++ b/src/optimizers/multiple.rs
@@ -2,17 +2,17 @@
 //!
 //! This module provides optimizers that can handle multiple manifold types within
 //! a single optimization step, allowing complex models with heterogeneous constraints.
-//! 
+//!
 //! Users can implement their own manifolds and use them with the multi-manifold optimizer.
use std::collections::HashMap; -use std::marker::PhantomData; use std::fmt::Debug; +use std::marker::PhantomData; use burn::module::Module; +use crate::constrained_module::Constrained; use crate::prelude::*; -use crate::manifolds::Constrained; /// Multi-manifold optimizer configuration #[derive(Debug, Clone)] @@ -39,11 +39,13 @@ impl Default for MultiManifoldOptimizerConfig { /// Multi-manifold optimizer that can handle different manifold types #[derive(Debug)] pub struct MultiManifoldOptimizer { + #[allow(unused)] config: MultiManifoldOptimizerConfig, _backend: PhantomData, } impl MultiManifoldOptimizer { + #[must_use] pub fn new(config: MultiManifoldOptimizerConfig) -> Self { Self { config, @@ -73,8 +75,9 @@ impl MultiManifoldOptimizer { /// Extension trait for modules with manifold constraints pub trait ManifoldOptimizable: Module { /// Apply manifold constraints to the module + #[must_use] fn apply_manifold_constraints(self) -> Self; - + /// Get information about manifold constraints fn get_manifold_info(&self) -> HashMap; } @@ -90,7 +93,7 @@ where // Apply constraints to the inner module and wrap it back self } - + fn get_manifold_info(&self) -> HashMap { let mut info = HashMap::new(); info.insert("manifold_type".to_string(), Man::name().to_string()); @@ -106,12 +109,11 @@ mod tests { type TestBackend = NdArray; - #[test] fn test_multi_manifold_optimizer() { let config = MultiManifoldOptimizerConfig::default(); let optimizer = MultiManifoldOptimizer::::new(config); - + // Test basic construction assert_eq!(optimizer.config.learning_rate, 1e-3); } @@ -121,7 +123,7 @@ mod tests { let device = Default::default(); let linear = LinearConfig::new(2, 2).init::(&device); let constrained_linear = Constrained::<_, Euclidean>::new(linear); - + let info = constrained_linear.get_manifold_info(); assert_eq!(info.get("manifold_type"), Some(&"Euclidean".to_string())); } @@ -130,15 +132,18 @@ mod tests { fn test_apply_constraints() { let config = MultiManifoldOptimizerConfig::default(); let optimizer = MultiManifoldOptimizer::::new(config); - + let device = Default::default(); let linear = LinearConfig::new(2, 2).init::(&device); let constrained_linear = Constrained::<_, Euclidean>::new(linear); - + // Test applying constraints let result = optimizer.apply_constraints(constrained_linear); - + // Should return the same module since we have a simplified implementation - assert_eq!(result.get_manifold_info().get("manifold_type"), Some(&"Euclidean".to_string())); + assert_eq!( + result.get_manifold_info().get("manifold_type"), + Some(&"Euclidean".to_string()) + ); } -} \ No newline at end of file +}
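// Illustrative sketch (not part of the diff): the `SimpleHessianOptimizer` blanket
// implementation added in src/optimizers/hessian_optimizer.rs lets any `SimpleOptimizer`
// be driven through the Hessian-aware entry point, and passing `None` for the Hessian falls
// back to the plain `SimpleOptimizer::step`. The fully qualified call avoids the name clash
// between the two `step` methods; the optimizer, tensors, and learning rate are placeholders,
// and `Backend`/`Tensor` are assumed to come from the crate prelude.
fn hessian_aware_step<B: Backend, O: SimpleHessianOptimizer<B>>(
    optimizer: &O,
    point: Tensor<B, 2>,
    grad: Tensor<B, 2>,
) -> Tensor<B, 2> {
    // No Hessian is available at this call site, so this reduces to a first-order step.
    let (next, _state) = SimpleHessianOptimizer::step(optimizer, 0.1, point, grad, None, None);
    next
}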