Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion platforms/cuda/src/CudaPyTorchKernelsE2EDiffConf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,87 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
return data;
}

// -------------------- Least-squares rigid-rotation remover (memory-efficient) --------------------
// Returns corrected_dX (N,3)
torch::Tensor remove_rigid_rotation_lstsq_loop(
const torch::Tensor& coords_in,
const torch::Tensor& dX_in,
bool center = true,
double lambda_reg = 1e-8
) {
TORCH_CHECK(coords_in.dim() == 2 && coords_in.size(1) == 3, "coords must be (N,3)");
TORCH_CHECK(dX_in.dim() == 2 && dX_in.size(1) == 3, "dX must be (N,3)");
TORCH_CHECK(coords_in.size(0) == dX_in.size(0), "coords and dX must have same N");

auto device = coords_in.device();
auto dtype = coords_in.dtype();
const int64_t N = coords_in.size(0);

// center if requested
auto coords = coords_in;
auto dX = dX_in;
if (center) {
auto coords_mean = coords_in.mean(0, /*keepdim=*/true);
auto dX_mean = dX_in.mean(0, /*keepdim=*/true);
coords = coords_in - coords_mean;
dX = dX_in - dX_mean;
}

// Prepare accumulators on device/dtype
auto I3 = torch::eye(3, torch::TensorOptions().device(device).dtype(dtype));
auto A = torch::zeros({3,3}, torch::TensorOptions().device(device).dtype(dtype));
auto b = torch::zeros({3}, torch::TensorOptions().device(device).dtype(dtype));

// Loop accumulate A and b in a memory-friendly way (vectorized in chunks if desired)
const int64_t chunk = 1 << 16; // process in chunks to reduce kernel launches if N large
for (int64_t start = 0; start < N; start += chunk) {
int64_t end = std::min<int64_t>(N, start + chunk);
auto r_chunk = coords.slice(0, start, end); // (M,3)
auto d_chunk = dX.slice(0, start, end); // (M,3)

// rsq: (M,)
auto rsq = torch::sum(r_chunk * r_chunk, 1);

// Compute per-chunk A contribution: sum ( rsq_i * I - r_i r_i^T )
// Use vectorized outer: (M,3,1) x (M,1,3) => (M,3,3)
auto r_col = r_chunk.unsqueeze(2);
auto r_row = r_chunk.unsqueeze(1);
auto outer = r_col.matmul(r_row); // (M,3,3)

// rsq * I per sample
auto rsq_exp = rsq.view({-1,1,1}); // (M,1,1)
auto rsqI = rsq_exp * I3.view({1,3,3}); // (M,3,3)

auto A_per = rsqI - outer; // (M,3,3)
auto A_sum = torch::sum(A_per, 0); // (3,3)
A += A_sum;

// b contribution: sum r_i x d_i
auto cross_rd = torch::cross(r_chunk, d_chunk, /*dim=*/1); // (M,3)
auto b_sum = torch::sum(cross_rd, 0); // (3,)
b += b_sum;
}

// Regularize
if (lambda_reg > 0.0) {
A = A + lambda_reg * I3;
}

// Solve A w = b
auto b_col = b.view({3,1});
torch::Tensor w_col;
// Try available linalg::solve signatures robustly
w_col = torch::linalg::solve(A, b_col, /*left=*/true);
auto w = w_col.view({3});

// Compute corrected displacements: dX_corr = dX - w x r
auto w_expand = w.view({1,3}).expand({coords.size(0),3});
auto wx = torch::cross(w_expand, coords, /*dim=*/1);
auto dX_corr = dX - wx;

return dX_corr;
}

CudaCalcPyTorchForceE2EDiffConfKernel::CudaCalcPyTorchForceE2EDiffConfKernel(string name, const Platform& platform, CudaContext& cu): CalcPyTorchForceE2EDiffConfKernel(name, platform), hasInitializedKernel(false), cu(cu) {
// Explicitly activate the primary context
CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context");
Expand Down Expand Up @@ -281,8 +362,8 @@ double CudaCalcPyTorchForceE2EDiffConfKernel::execute(ContextImpl& context,bool
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context");

torch::Tensor noise = scale*nnModule.forward(nnInputs).toTensor();
noise = remove_rigid_rotation_lstsq_loop(positionsTensor.squeeze(1), noise);

//std::cout << "tim_idx, sigfac, mean noise:" << tim_idx << " " << sigfac << " " << torch::pow(noise, 2).mean() << "\n";
// get forces on positions as before
if (includeForces) {

Expand Down
84 changes: 83 additions & 1 deletion platforms/reference/src/ReferencePyTorchKernelsE2EDiffConf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,88 @@ static std::vector<double> extractContextVariables(ContextImpl& context, int num
return signals;
}

// -------------------- Least-squares rigid-rotation remover (memory-efficient) --------------------
// Returns corrected_dX (N,3)
torch::Tensor remove_rigid_rotation_lstsq_loop(
const torch::Tensor& coords_in,
const torch::Tensor& dX_in,
bool center = true,
double lambda_reg = 1e-8
) {
TORCH_CHECK(coords_in.dim() == 2 && coords_in.size(1) == 3, "coords must be (N,3)");
TORCH_CHECK(dX_in.dim() == 2 && dX_in.size(1) == 3, "dX must be (N,3)");
TORCH_CHECK(coords_in.size(0) == dX_in.size(0), "coords and dX must have same N");

auto device = coords_in.device();
auto dtype = coords_in.dtype();
const int64_t N = coords_in.size(0);

// center if requested
auto coords = coords_in;
auto dX = dX_in;
if (center) {
auto coords_mean = coords_in.mean(0, /*keepdim=*/true);
auto dX_mean = dX_in.mean(0, /*keepdim=*/true);
coords = coords_in - coords_mean;
dX = dX_in - dX_mean;
}

// Prepare accumulators on device/dtype
auto I3 = torch::eye(3, torch::TensorOptions().device(device).dtype(dtype));
auto A = torch::zeros({3,3}, torch::TensorOptions().device(device).dtype(dtype));
auto b = torch::zeros({3}, torch::TensorOptions().device(device).dtype(dtype));

// Loop accumulate A and b in a memory-friendly way (vectorized in chunks if desired)
const int64_t chunk = 1 << 16; // process in chunks to reduce kernel launches if N large
for (int64_t start = 0; start < N; start += chunk) {
int64_t end = std::min<int64_t>(N, start + chunk);
auto r_chunk = coords.slice(0, start, end); // (M,3)
auto d_chunk = dX.slice(0, start, end); // (M,3)

// rsq: (M,)
auto rsq = torch::sum(r_chunk * r_chunk, 1);

// Compute per-chunk A contribution: sum ( rsq_i * I - r_i r_i^T )
// Use vectorized outer: (M,3,1) x (M,1,3) => (M,3,3)
auto r_col = r_chunk.unsqueeze(2);
auto r_row = r_chunk.unsqueeze(1);
auto outer = r_col.matmul(r_row); // (M,3,3)

// rsq * I per sample
auto rsq_exp = rsq.view({-1,1,1}); // (M,1,1)
auto rsqI = rsq_exp * I3.view({1,3,3}); // (M,3,3)

auto A_per = rsqI - outer; // (M,3,3)
auto A_sum = torch::sum(A_per, 0); // (3,3)
A += A_sum;

// b contribution: sum r_i x d_i
auto cross_rd = torch::cross(r_chunk, d_chunk, /*dim=*/1); // (M,3)
auto b_sum = torch::sum(cross_rd, 0); // (3,)
b += b_sum;
}

// Regularize
if (lambda_reg > 0.0) {
A = A + lambda_reg * I3;
}

// Solve A w = b
auto b_col = b.view({3,1});
torch::Tensor w_col;
// Try available linalg::solve signatures robustly
w_col = torch::linalg::solve(A, b_col, /*left=*/true);
auto w = w_col.view({3});

// Compute corrected displacements: dX_corr = dX - w x r
auto w_expand = w.view({1,3}).expand({coords.size(0),3});
auto wx = torch::cross(w_expand, coords, /*dim=*/1);
auto dX_corr = dX - wx;

return dX_corr;
}


ReferenceCalcPyTorchForceE2EDiffConfKernel::~ReferenceCalcPyTorchForceE2EDiffConfKernel() {
}

Expand Down Expand Up @@ -259,7 +341,7 @@ double ReferenceCalcPyTorchForceE2EDiffConfKernel::execute(ContextImpl& context,

//std::cout << "get_diffusion_noise device:" << get_diffusion_noise.device();
torch::Tensor noise = scale*nnModule.forward(nnInputs).toTensor();

noise = remove_rigid_rotation_lstsq_loop(positionsTensor.squeeze(1), noise);

// get forces on positions as before
if (includeForces) {
Expand Down
13 changes: 7 additions & 6 deletions python/mlforce.i
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
#include "openmm/RPMDIntegrator.h"
#include "openmm/RPMDMonteCarloBarostat.h"
%}
namespace std {
%template(vectori) vector<int>;
%template(vectord) vector<double>;
%template(vectordd) vector< vector<double> >;
%template(vectorii) vector< vector<int> >;
};
%template(vectori) std::vector<int>;
%template(vectorii) std::vector<std::vector<int> >;
%template(vectord) std::vector<double>;
%template(vectordd) std::vector<std::vector<double> >;
%template(vectorf) std::vector<float> ;
%template(vectorff) std::vector<std::vector<float> >;


namespace PyTorchPlugin {

Expand Down