145 changes: 108 additions & 37 deletions examples/multi_constraints.rs
@@ -1,17 +1,21 @@
use burn::module::Module;
use burn::module::ModuleVisitor;
use burn::nn::LinearConfig;
use burn::nn::Linear;
use burn::module::Module;
use burn::nn::LinearConfig;

use manopt_rs::manifolds::Constrained;
use manopt_rs::optimizers::multiple::{MultiManifoldOptimizer, MultiManifoldOptimizerConfig, ManifoldOptimizable};
use manopt_rs::constrained_module::Constrained;
use manopt_rs::optimizers::multiple::{
ManifoldOptimizable, MultiManifoldOptimizer, MultiManifoldOptimizerConfig,
};
use manopt_rs::prelude::*;

// Example: User-defined custom manifold
#[derive(Clone, Debug)]
pub struct CustomSphereManifold;

impl<B: Backend> Manifold<B> for CustomSphereManifold {
const RANK_PER_POINT: usize = 1;

fn new() -> Self {
Self
}
@@ -22,38 +26,70 @@ impl<B: Backend> Manifold<B> for CustomSphereManifold {

fn project<const D: usize>(point: Tensor<B, D>, vector: Tensor<B, D>) -> Tensor<B, D> {
// For sphere: project vector orthogonal to point
let dot_product = (point.clone() * vector.clone()).sum();
vector - point * dot_product.unsqueeze()
debug_assert!(point.shape() == vector.shape());
let dot_product =
(point.clone() * vector.clone()).sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT);
vector - point * dot_product
}

fn retract<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
// For sphere: normalize the result
debug_assert!(point.shape() == direction.shape());
let new_point = point + direction;
let norm = new_point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze();
let norm = new_point
.clone()
.powf_scalar(2.0)
.sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT)
.sqrt();
new_point / norm
}

fn inner<const D: usize>(_point: Tensor<B, D>, u: Tensor<B, D>, v: Tensor<B, D>) -> Tensor<B, D> {
u * v
fn inner<const D: usize>(
_point: Tensor<B, D>,
u: Tensor<B, D>,
v: Tensor<B, D>,
) -> Tensor<B, D> {
(u * v).sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT)
}

fn proj<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D> {
// Project point onto unit sphere
let norm = point.clone().powf_scalar(2.0).sum().sqrt().unsqueeze();
let norm = point
.clone()
.powf_scalar(2.0)
.sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT)
.sqrt();
point / norm
}

fn is_in_manifold<const D: usize>(_point: Tensor<B, D>) -> bool {
// For now, just return true - in a real implementation you'd check the constraint
true
fn is_in_manifold<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D, Bool> {
let r_squared = point
.powf_scalar(2.0)
.sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT);
let one = r_squared.ones_like();
r_squared.is_close(one, None, None)
}

fn is_tangent_at<const D: usize>(
point: Tensor<B, D>,
vector: Tensor<B, D>,
) -> Tensor<B, D, Bool> {
let dot_product = (point * vector).sum_dim(D - <Self as Manifold<B>>::RANK_PER_POINT);
let zeros = dot_product.zeros_like();
dot_product.is_close(zeros, None, Some(1e-6))
}

fn acceptable_dims(a_is: &[usize]) -> bool {
let n = *a_is.first().expect("The ambient R^n does exist");
n > 0
}
}

#[derive(Debug, Clone)]
pub struct TestModel<B: Backend> {
// Euclidean constrained linear layer
linear_euclidean: Constrained<Linear<B>, Euclidean>,
// Custom sphere constrained linear layer
// Custom sphere constrained linear layer
linear_sphere: Constrained<Linear<B>, CustomSphereManifold>,
// Regular unconstrained linear layer
linear_regular: Linear<B>,
@@ -125,7 +161,7 @@ impl<B: Backend> TestModel<B> {
let linear_euclidean = LinearConfig::new(10, 5).init(device);
let linear_sphere = LinearConfig::new(5, 3).init(device);
let linear_regular = LinearConfig::new(3, 1).init(device);

Self {
linear_euclidean: Constrained::new(linear_euclidean),
linear_sphere: Constrained::new(linear_sphere),
@@ -138,66 +174,101 @@ struct ManifoldAwareVisitor;

impl<B: Backend> ModuleVisitor<B> for ManifoldAwareVisitor {
fn visit_float<const D: usize>(&mut self, id: burn::module::ParamId, tensor: &Tensor<B, D>) {
println!("Visiting parameter: {:?} with shape: {:?}", id, tensor.dims());
println!(
"Visiting parameter: {:?} with shape: {:?}",
id,
tensor.dims()
);
}

fn visit_int<const D: usize>(&mut self, _id: burn::module::ParamId, _tensor: &Tensor<B, D, Int>) {}
fn visit_int<const D: usize>(
&mut self,
_id: burn::module::ParamId,
_tensor: &Tensor<B, D, Int>,
) {
}

fn visit_bool<const D: usize>(&mut self, _id: burn::module::ParamId, _tensor: &Tensor<B, D, Bool>) {}
fn visit_bool<const D: usize>(
&mut self,
_id: burn::module::ParamId,
_tensor: &Tensor<B, D, Bool>,
) {
}
}

fn main() {
type MyBackend = burn::backend::NdArray;
type AutoDiffBackend = burn::backend::Autodiff<MyBackend>;

let device = Default::default();

// Create a model with mixed manifold constraints
let model = TestModel::<AutoDiffBackend>::new(&device);

println!("=== Model Structure ===");
println!("Euclidean layer manifold: {}", model.linear_euclidean.manifold_name::<AutoDiffBackend>());
println!("Sphere layer manifold: {}", model.linear_sphere.manifold_name::<AutoDiffBackend>());

println!(
"Euclidean layer manifold: {}",
model.linear_euclidean.manifold_name::<AutoDiffBackend>()
);
println!(
"Sphere layer manifold: {}",
model.linear_sphere.manifold_name::<AutoDiffBackend>()
);

// Create multi-manifold optimizer
let config = MultiManifoldOptimizerConfig::default();
let mut optimizer = MultiManifoldOptimizer::new(config);

// Collect manifold information from the model
optimizer.collect_manifolds(&model);

// Register custom manifold for specific parameters (if needed)
optimizer.register_manifold::<CustomSphereManifold>("linear_sphere.weight".to_string());

println!("\n=== Manifold Information ===");
println!("Euclidean info: {:?}", model.linear_euclidean.get_manifold_info());
println!(
"Euclidean info: {:?}",
model.linear_euclidean.get_manifold_info()
);
println!("Sphere info: {:?}", model.linear_sphere.get_manifold_info());

// Example of applying constraints
let constrained_model = optimizer.apply_constraints(model);

// Visit the model to see parameter structure
println!("\n=== Parameter Structure ===");
let mut visitor = ManifoldAwareVisitor;
constrained_model.visit(&mut visitor);

println!("\n=== Demonstrating Custom Manifold Operations ===");

// Show how the custom sphere manifold works
let point = Tensor::<MyBackend, 1>::from_floats([3.0, 4.0, 0.0], &device);
let vector = Tensor::<MyBackend, 1>::from_floats([1.0, 1.0, 1.0], &device);

println!("Original point: {:?}", point.to_data());
println!("Original vector: {:?}", vector.to_data());

// Project point to sphere
let projected_point = CustomSphereManifold::proj(point.clone());
println!("Point projected to sphere: {:?}", projected_point.to_data());

// Project vector to tangent space
let projected_vector = CustomSphereManifold::project(projected_point.clone(), vector);
println!("Vector projected to tangent space: {:?}", projected_vector.to_data());

println!(
"Vector projected to tangent space: {:?}",
projected_vector.to_data()
);

// Check if point is on manifold
println!("Is projected point on sphere? {}", CustomSphereManifold::is_in_manifold(projected_point));
println!(
"Is projected point on sphere? {}",
CustomSphereManifold::is_in_manifold(projected_point.clone())
);

// Check if vector is tangent at point on manifold
println!(
"Is projected vector tangent to point on sphere? {}",
CustomSphereManifold::is_tangent_at(projected_point, projected_vector)
);
}
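
For reference, the sphere operations exercised at the end of multi_constraints.rs reduce to two formulas: proj(p) = p / ||p|| and project(p, v) = v − p·⟨p, v⟩ for unit-norm p. Below is a minimal standalone sketch in plain Rust (no burn or manopt_rs dependency; the inputs match the example's [3, 4, 0] and [1, 1, 1]) that reproduces the values the example should print:

// Standalone check of the sphere arithmetic used in the example above.
// proj(p)       = p / ||p||
// project(p, v) = v - p * <p, v>    (p assumed unit-norm)
fn main() {
    let p = [3.0f32, 4.0, 0.0];
    let v = [1.0f32, 1.0, 1.0];

    // proj: normalize p onto the unit sphere -> [0.6, 0.8, 0.0]
    let norm = p.iter().map(|x| x * x).sum::<f32>().sqrt();
    let p_unit: Vec<f32> = p.iter().map(|x| x / norm).collect();

    // project: remove the radial component of v -> [0.16, -0.12, 1.0]
    let dot: f32 = p_unit.iter().zip(&v).map(|(a, b)| a * b).sum();
    let v_tan: Vec<f32> = v
        .iter()
        .zip(&p_unit)
        .map(|(vi, pi)| vi - pi * dot)
        .collect();

    // tangency check: <p_unit, v_tan> should be ~0
    let tangent_dot: f32 = p_unit.iter().zip(&v_tan).map(|(a, b)| a * b).sum();
    println!("p_unit = {:?}", p_unit);
    println!("v_tan = {:?}", v_tan);
    println!("<p_unit, v_tan> = {:e}", tangent_dot);
}

With exact arithmetic this yields p_unit = [0.6, 0.8, 0.0], v_tan = [0.16, -0.12, 1.0], and a tangency inner product of 0, which is what is_tangent_at checks up to the 1e-6 tolerance.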
35 changes: 31 additions & 4 deletions examples/optimization_demo.rs
@@ -1,5 +1,7 @@
use std::collections::HashMap;

use burn::optim::SimpleOptimizer;
use manopt_rs::prelude::*;
use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*};

fn main() {
// Configure the optimizer
@@ -18,6 +20,8 @@ fn main() {
Tensor::<burn::backend::NdArray, 1>::from_floats([0.0, 0.0, 0.0], &Default::default());
let mut state = None;

let mut loss_decay: HashMap<usize, f32> = HashMap::new();

println!("Target: {}", target);
println!("Initial x: {}", x);
println!("\nOptimization steps:");
@@ -35,13 +39,36 @@ fn main() {
// Print progress every 10 steps
if step % 10 == 0 {
let loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();
println!("Step {}: x = {}, loss = {}", step, x, loss);
let loss_scalar = loss.into_scalar();
println!("Step {}: x = {}, loss = {:.5}", step, x, loss_scalar);
loss_decay.insert(step, loss_scalar);
}
}

println!("\nResult after 100:");
println!("x = {}", x);
println!("Target = {}", target);
let final_loss = (x.clone() - target.clone())
.powf_scalar(2.0)
.sum()
.into_scalar();
println!("Loss after 100 = {:.5}", final_loss);
loss_decay.insert(100, final_loss);

// Perform optimization steps
(x, state) = optimizer.many_steps(|_| 1.0, 400, |x| (x - target.clone()) * 2.0, x, state);

println!("\nFinal result:");
println!("x = {}", x);
println!("Target = {}", target);
let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();
println!("Final loss = {}", final_loss);
let final_loss = (x.clone() - target.clone())
.powf_scalar(2.0)
.sum()
.into_scalar();
println!("Final loss = {:.5}", final_loss);
println!("State is set {}", state.is_some());
loss_decay.insert(500, final_loss);
let mut sorted_losses: Vec<(usize, f32)> = loss_decay.into_iter().collect();
sorted_losses.sort_by_key(|z| z.0);
println!("The loss decayed as follows: {:?}", sorted_losses);
}
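
The many_steps call in optimization_demo.rs drives plain gradient descent on ||x − target||², using the gradient 2·(x − target) supplied as a closure. The following self-contained sketch of that loop shows the kind of loss decay the loss_decay map records; the 0.1 learning rate and the [1, 2, 3] target are assumptions, since the optimizer config and target definition sit outside this hunk:

// Plain-Rust sketch of the descent that many_steps drives above:
// minimize ||x - target||^2 with gradient 2 * (x - target).
fn main() {
    let target = [1.0f32, 2.0, 3.0]; // illustrative target; the demo's target is defined elsewhere
    let mut x = [0.0f32; 3];
    let lr = 0.1; // assumed step size

    for step in 0..=500 {
        if step % 100 == 0 {
            let loss: f32 = x.iter().zip(&target).map(|(a, b)| (a - b).powi(2)).sum();
            println!("step {step}: loss = {loss:.5}");
        }
        for (xi, ti) in x.iter_mut().zip(&target) {
            *xi -= lr * 2.0 * (*xi - *ti); // one gradient-descent update
        }
    }
}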
16 changes: 14 additions & 2 deletions examples/riemannian_adam_demo.rs
@@ -1,5 +1,5 @@
use burn::optim::SimpleOptimizer;
use manopt_rs::prelude::*;
use manopt_rs::{optimizers::LessSimpleOptimizer, prelude::*};

fn main() {
println!("Testing Riemannian Adam optimizer...");
@@ -20,8 +20,20 @@ fn main() {
let (new_tensor, state) = optimizer.step(1.0, tensor.clone(), grad, None);

println!("Original tensor: {}", tensor);
println!("New tensor: {}", new_tensor);
println!("New tensor: {:.4}", new_tensor);
println!("State initialized: {}", state.is_some());

fn grad_function<B: Backend>(_x: Tensor<B, 2>) -> Tensor<B, 2> {
Tensor::<B, 2>::ones([2, 2], &Default::default())
}
// Perform more optimization steps
let (new_tensor, state) =
optimizer.many_steps(|_| 1.0, 9, grad_function, tensor.clone(), state);
println!("After even more steps Tensor: {:.4}", new_tensor);
println!(
"After even more steps State initialized: {}",
state.is_some()
);

println!("Riemannian Adam test completed successfully!");
}
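
riemannian_adam_demo.rs feeds a constant all-ones gradient into the optimizer. The sketch below runs the standard (Euclidean) Adam recursion on a scalar with that same constant gradient, assuming lr = 1.0 and the usual defaults β1 = 0.9, β2 = 0.999, ε = 1e-8; the Riemannian variant in manopt_rs would additionally project the gradient to the tangent space and retract the update back onto the manifold:

// Scalar sketch of the Euclidean Adam update that the Riemannian variant extends.
// Constant gradient 1.0 mirrors grad_function above; lr and betas are assumed defaults.
fn main() {
    let (lr, beta1, beta2, eps) = (1.0f64, 0.9, 0.999, 1e-8);
    let (mut x, mut m, mut v) = (1.0f64, 0.0f64, 0.0f64);

    for t in 1..=10 {
        let g = 1.0; // constant gradient, as in grad_function above
        m = beta1 * m + (1.0 - beta1) * g; // first-moment estimate
        v = beta2 * v + (1.0 - beta2) * g * g; // second-moment estimate
        let m_hat = m / (1.0 - beta1.powi(t)); // bias correction
        let v_hat = v / (1.0 - beta2.powi(t));
        x -= lr * m_hat / (v_hat.sqrt() + eps);
        println!("step {t}: x = {x:.4}");
    }
}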