diff --git a/platforms/cuda/src/CudaPyTorchKernelsE2EDiffConf.cpp b/platforms/cuda/src/CudaPyTorchKernelsE2EDiffConf.cpp index 2b93635..3d8b7f8 100644 --- a/platforms/cuda/src/CudaPyTorchKernelsE2EDiffConf.cpp +++ b/platforms/cuda/src/CudaPyTorchKernelsE2EDiffConf.cpp @@ -58,6 +58,87 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) { return data; } +// -------------------- Least-squares rigid-rotation remover (memory-efficient) -------------------- +// Returns corrected_dX (N,3) +torch::Tensor remove_rigid_rotation_lstsq_loop( + const torch::Tensor& coords_in, + const torch::Tensor& dX_in, + bool center = true, + double lambda_reg = 1e-8 +) { + TORCH_CHECK(coords_in.dim() == 2 && coords_in.size(1) == 3, "coords must be (N,3)"); + TORCH_CHECK(dX_in.dim() == 2 && dX_in.size(1) == 3, "dX must be (N,3)"); + TORCH_CHECK(coords_in.size(0) == dX_in.size(0), "coords and dX must have same N"); + + auto device = coords_in.device(); + auto dtype = coords_in.dtype(); + const int64_t N = coords_in.size(0); + + // center if requested + auto coords = coords_in; + auto dX = dX_in; + if (center) { + auto coords_mean = coords_in.mean(0, /*keepdim=*/true); + auto dX_mean = dX_in.mean(0, /*keepdim=*/true); + coords = coords_in - coords_mean; + dX = dX_in - dX_mean; + } + + // Prepare accumulators on device/dtype + auto I3 = torch::eye(3, torch::TensorOptions().device(device).dtype(dtype)); + auto A = torch::zeros({3,3}, torch::TensorOptions().device(device).dtype(dtype)); + auto b = torch::zeros({3}, torch::TensorOptions().device(device).dtype(dtype)); + + // Loop accumulate A and b in a memory-friendly way (vectorized in chunks if desired) + const int64_t chunk = 1 << 16; // process in chunks to reduce kernel launches if N large + for (int64_t start = 0; start < N; start += chunk) { + int64_t end = std::min(N, start + chunk); + auto r_chunk = coords.slice(0, start, end); // (M,3) + auto d_chunk = dX.slice(0, start, end); // (M,3) + + // rsq: (M,) + auto rsq = torch::sum(r_chunk * r_chunk, 1); + + // Compute per-chunk A contribution: sum ( rsq_i * I - r_i r_i^T ) + // Use vectorized outer: (M,3,1) x (M,1,3) => (M,3,3) + auto r_col = r_chunk.unsqueeze(2); + auto r_row = r_chunk.unsqueeze(1); + auto outer = r_col.matmul(r_row); // (M,3,3) + + // rsq * I per sample + auto rsq_exp = rsq.view({-1,1,1}); // (M,1,1) + auto rsqI = rsq_exp * I3.view({1,3,3}); // (M,3,3) + + auto A_per = rsqI - outer; // (M,3,3) + auto A_sum = torch::sum(A_per, 0); // (3,3) + A += A_sum; + + // b contribution: sum r_i x d_i + auto cross_rd = torch::cross(r_chunk, d_chunk, /*dim=*/1); // (M,3) + auto b_sum = torch::sum(cross_rd, 0); // (3,) + b += b_sum; + } + + // Regularize + if (lambda_reg > 0.0) { + A = A + lambda_reg * I3; + } + + // Solve A w = b + auto b_col = b.view({3,1}); + torch::Tensor w_col; + // Try available linalg::solve signatures robustly + w_col = torch::linalg::solve(A, b_col, /*left=*/true); + auto w = w_col.view({3}); + + // Compute corrected displacements: dX_corr = dX - w x r + auto w_expand = w.view({1,3}).expand({coords.size(0),3}); + auto wx = torch::cross(w_expand, coords, /*dim=*/1); + auto dX_corr = dX - wx; + + return dX_corr; +} + CudaCalcPyTorchForceE2EDiffConfKernel::CudaCalcPyTorchForceE2EDiffConfKernel(string name, const Platform& platform, CudaContext& cu): CalcPyTorchForceE2EDiffConfKernel(name, platform), hasInitializedKernel(false), cu(cu) { // Explicitly activate the primary context CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context"); @@ -281,8 +362,8 @@ double CudaCalcPyTorchForceE2EDiffConfKernel::execute(ContextImpl& context,bool CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); torch::Tensor noise = scale*nnModule.forward(nnInputs).toTensor(); + noise = remove_rigid_rotation_lstsq_loop(positionsTensor.squeeze(1), noise); - //std::cout << "tim_idx, sigfac, mean noise:" << tim_idx << " " << sigfac << " " << torch::pow(noise, 2).mean() << "\n"; // get forces on positions as before if (includeForces) { diff --git a/platforms/reference/src/ReferencePyTorchKernelsE2EDiffConf.cpp b/platforms/reference/src/ReferencePyTorchKernelsE2EDiffConf.cpp index ee1b99e..6339266 100644 --- a/platforms/reference/src/ReferencePyTorchKernelsE2EDiffConf.cpp +++ b/platforms/reference/src/ReferencePyTorchKernelsE2EDiffConf.cpp @@ -72,6 +72,88 @@ static std::vector extractContextVariables(ContextImpl& context, int num return signals; } +// -------------------- Least-squares rigid-rotation remover (memory-efficient) -------------------- +// Returns corrected_dX (N,3) +torch::Tensor remove_rigid_rotation_lstsq_loop( + const torch::Tensor& coords_in, + const torch::Tensor& dX_in, + bool center = true, + double lambda_reg = 1e-8 +) { + TORCH_CHECK(coords_in.dim() == 2 && coords_in.size(1) == 3, "coords must be (N,3)"); + TORCH_CHECK(dX_in.dim() == 2 && dX_in.size(1) == 3, "dX must be (N,3)"); + TORCH_CHECK(coords_in.size(0) == dX_in.size(0), "coords and dX must have same N"); + + auto device = coords_in.device(); + auto dtype = coords_in.dtype(); + const int64_t N = coords_in.size(0); + + // center if requested + auto coords = coords_in; + auto dX = dX_in; + if (center) { + auto coords_mean = coords_in.mean(0, /*keepdim=*/true); + auto dX_mean = dX_in.mean(0, /*keepdim=*/true); + coords = coords_in - coords_mean; + dX = dX_in - dX_mean; + } + + // Prepare accumulators on device/dtype + auto I3 = torch::eye(3, torch::TensorOptions().device(device).dtype(dtype)); + auto A = torch::zeros({3,3}, torch::TensorOptions().device(device).dtype(dtype)); + auto b = torch::zeros({3}, torch::TensorOptions().device(device).dtype(dtype)); + + // Loop accumulate A and b in a memory-friendly way (vectorized in chunks if desired) + const int64_t chunk = 1 << 16; // process in chunks to reduce kernel launches if N large + for (int64_t start = 0; start < N; start += chunk) { + int64_t end = std::min(N, start + chunk); + auto r_chunk = coords.slice(0, start, end); // (M,3) + auto d_chunk = dX.slice(0, start, end); // (M,3) + + // rsq: (M,) + auto rsq = torch::sum(r_chunk * r_chunk, 1); + + // Compute per-chunk A contribution: sum ( rsq_i * I - r_i r_i^T ) + // Use vectorized outer: (M,3,1) x (M,1,3) => (M,3,3) + auto r_col = r_chunk.unsqueeze(2); + auto r_row = r_chunk.unsqueeze(1); + auto outer = r_col.matmul(r_row); // (M,3,3) + + // rsq * I per sample + auto rsq_exp = rsq.view({-1,1,1}); // (M,1,1) + auto rsqI = rsq_exp * I3.view({1,3,3}); // (M,3,3) + + auto A_per = rsqI - outer; // (M,3,3) + auto A_sum = torch::sum(A_per, 0); // (3,3) + A += A_sum; + + // b contribution: sum r_i x d_i + auto cross_rd = torch::cross(r_chunk, d_chunk, /*dim=*/1); // (M,3) + auto b_sum = torch::sum(cross_rd, 0); // (3,) + b += b_sum; + } + + // Regularize + if (lambda_reg > 0.0) { + A = A + lambda_reg * I3; + } + + // Solve A w = b + auto b_col = b.view({3,1}); + torch::Tensor w_col; + // Try available linalg::solve signatures robustly + w_col = torch::linalg::solve(A, b_col, /*left=*/true); + auto w = w_col.view({3}); + + // Compute corrected displacements: dX_corr = dX - w x r + auto w_expand = w.view({1,3}).expand({coords.size(0),3}); + auto wx = torch::cross(w_expand, coords, /*dim=*/1); + auto dX_corr = dX - wx; + + return dX_corr; +} + + ReferenceCalcPyTorchForceE2EDiffConfKernel::~ReferenceCalcPyTorchForceE2EDiffConfKernel() { } @@ -259,7 +341,7 @@ double ReferenceCalcPyTorchForceE2EDiffConfKernel::execute(ContextImpl& context, //std::cout << "get_diffusion_noise device:" << get_diffusion_noise.device(); torch::Tensor noise = scale*nnModule.forward(nnInputs).toTensor(); - + noise = remove_rigid_rotation_lstsq_loop(positionsTensor.squeeze(1), noise); // get forces on positions as before if (includeForces) { diff --git a/python/mlforce.i b/python/mlforce.i index f0b4028..c1039c7 100755 --- a/python/mlforce.i +++ b/python/mlforce.i @@ -13,12 +13,13 @@ #include "openmm/RPMDIntegrator.h" #include "openmm/RPMDMonteCarloBarostat.h" %} -namespace std { - %template(vectori) vector; - %template(vectord) vector; - %template(vectordd) vector< vector >; - %template(vectorii) vector< vector >; - }; +%template(vectori) std::vector; +%template(vectorii) std::vector >; +%template(vectord) std::vector; +%template(vectordd) std::vector >; +%template(vectorf) std::vector ; +%template(vectorff) std::vector >; + namespace PyTorchPlugin {