JAX support #157
Replies: 11 comments · 44 replies
-
Hello Teddy, it's great to hear from you! That's the idea! I have not started working towards it yet, as the work on the torch integration is not yet finished. If/when we start work on JAX bindings, is there a project you have in mind where you would be interested in integrating it?
-
Is there any update on the JAX integration? It would be nice to also use it with mace-jax. I also plan to integrate these models for simulation in JAX-MD, and having these kernels would be great!
-
Hi @abhijeetgangan and @teddykoker - a note that we'll have JAX support merged shortly and would love some testing (@asglover is doing some on his end, but we want to stress-test). Since the packages aren't published to PyPI yet and the diff is not on main, the installation steps are below. The README on the jax_support branch gives a usage example, but the API is the same as OEQ PyTorch, so it shouldn't be too much of a problem. Let us know if there are any issues (and include the error and your JAX version if you can).
-
Update: JAX support is merged; installation is slightly easier now with `pip install openequivariance[jax]` followed by `pip install openequivariance_extjax --no-build-isolation`.
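A quick way to confirm the install worked is to import both packages and check that JAX can see a GPU. This is only a minimal sanity check and assumes the importable module names match the pip package names; it does not exercise any kernels.

```python
# Minimal post-install sanity check (assumption: module names mirror the pip
# package names; adjust the imports if your install differs).
import jax
import openequivariance          # core package
import openequivariance_extjax   # JAX extension; an ImportError here means the second install failed

print(jax.__version__)
print(jax.devices())  # should list at least one GPU for the kernels to be useful
```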
-
I was able to integrate the kernels with the nequix JAX models and also with the simulation engines (ASE and jax-md). Tests comparing the integrations are also present. It would be nice to have a new release. The speed seems good; below is a result of a simulation with and without the kernels in the two simulators (the script is available on the above fork):
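For anyone wiring a JAX potential into ASE in a similar way, the calculator side is mostly boilerplate. Below is a minimal sketch (not the actual nequix integration): it wraps a user-supplied `energy_fn(positions, cell)` in an ASE `Calculator` and derives forces with `jax.value_and_grad`; neighbor lists, periodic shifts, and stress are left out.

```python
import jax
import jax.numpy as jnp
import numpy as np
from ase.calculators.calculator import Calculator, all_changes


class JAXCalculator(Calculator):
    """Minimal ASE calculator wrapping a JAX energy function (illustrative sketch)."""

    implemented_properties = ["energy", "forces"]

    def __init__(self, energy_fn, **kwargs):
        super().__init__(**kwargs)
        value_and_grad = jax.value_and_grad(energy_fn, argnums=0)

        @jax.jit
        def energy_and_forces(positions, cell):
            e, de_dr = value_and_grad(positions, cell)
            return e, -de_dr  # forces are the negative gradient of the energy

        self._energy_and_forces = energy_and_forces

    def calculate(self, atoms=None, properties=("energy",), system_changes=all_changes):
        super().calculate(atoms, properties, system_changes)
        positions = jnp.asarray(atoms.get_positions())
        cell = jnp.asarray(np.array(atoms.get_cell()))
        e, f = self._energy_and_forces(positions, cell)
        self.results = {"energy": float(e), "forces": np.asarray(f)}
```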
-
Here is a demo for PFT (using @abhijeetgangan's fork). Right now it gets an error, but once higher-order derivatives are supported this should be a good demonstration of the speed/memory improvements.

```python
from pathlib import Path
import time
import urllib.request

import ase
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import phonopy
from nequix.data import atomic_numbers_to_indices, preprocess_graph
from nequix.model import Nequix
from tqdm import tqdm


def train(model, graph, ref_hessian, n_epochs=50, lr=0.003, label="HVP"):
    """Train against reference Hessian columns via Hessian-vector products."""
    x = graph["positions"]
    n_atoms, n_dof = x.shape[0], x.size

    def energy_fn(model, pos_flat):
        pos = pos_flat.reshape(n_atoms, 3)
        disp = pos[graph["senders"]] - pos[graph["receivers"]] + graph["shifts"] @ graph["cell"]
        return model.node_energies(
            disp, graph["species"], graph["senders"], graph["receivers"]
        ).sum()

    grad_fn = jax.grad(energy_fn, argnums=1)

    def hvp_fn(model, x, v):
        # Hessian-vector product: forward-mode JVP of the gradient (forward-over-reverse)
        return jax.jvp(lambda pos: grad_fn(model, pos), (x,), (v,))[1]

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def train_step_hvp(model, opt_state, x_flat, idx):
        def loss_fn(model):
            # MAE between one predicted Hessian column and the reference force constants
            v = jnp.zeros(n_dof, dtype=x.dtype).at[idx].set(1.0)
            hvp = hvp_fn(model, x_flat, v)
            return jnp.abs(hvp - ref_hessian[:, idx]).mean()

        loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
        updates, opt_state_new = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state_new, loss

    x_flat = x.flatten()
    device = jax.devices()[0]
    loss_history = []
    step_times = []
    warmup_steps = 5
    rng_key = jax.random.key(42)
    for epoch in tqdm(range(n_epochs), desc=label):
        jax.block_until_ready(model)
        step_start = time.perf_counter()
        rng_key, subkey = jax.random.split(rng_key)
        idx = jax.random.randint(subkey, (), 0, n_dof)
        model, opt_state, loss = train_step_hvp(model, opt_state, x_flat, idx)
        jax.block_until_ready(model)
        step_time = time.perf_counter() - step_start
        if epoch >= warmup_steps:
            step_times.append(step_time)
        loss_history.append(float(loss))

    avg_step_time = sum(step_times) / len(step_times) if step_times else 0
    mem_stats = device.memory_stats()
    peak_mem = mem_stats.get("peak_bytes_in_use", 0) / 1024**3 if mem_stats else 0
    return loss_history, avg_step_time, peak_mem


def pft():
    n_epochs = 1000
    cutoff = 5.0
    model_kwargs = dict(
        n_species=1,
        cutoff=cutoff,
        hidden_irreps="32x0e + 32x1o + 32x2e",
        n_layers=3,
        radial_basis_size=8,
        radial_mlp_size=64,
        radial_mlp_layers=2,
    )

    data_path = Path("mp-149.yaml")
    if not data_path.exists():
        url = "https://github.com/teddykoker/nequix-examples/raw/refs/heads/main/phonon/mp-149.yaml"
        print(f"downloading {url}...")
        urllib.request.urlretrieve(url, data_path)

    ph_ref = phonopy.load(data_path)
    ph_ref.produce_force_constants()
    atoms = ase.Atoms(
        symbols=ph_ref.supercell.symbols,
        positions=ph_ref.supercell.positions,
        cell=ph_ref.supercell.cell,
        pbc=True,
    )
    atom_indices = atomic_numbers_to_indices(set(atoms.get_atomic_numbers()))
    g = preprocess_graph(atoms, atom_indices, cutoff, targets=False)
    graph = {
        k: jnp.array(v) for k, v in g.items() if v is not None and k not in ("n_node", "n_edge")
    }
    graph["species"] = graph["species"].astype(jnp.int32)
    graph["senders"] = graph["senders"].astype(jnp.int32)
    graph["receivers"] = graph["receivers"].astype(jnp.int32)
    n_atoms = g["n_node"][0]

    # (n, n, 3, 3) -> (3n, 3n)
    ref_hessian = (
        jnp.array(ph_ref.force_constants, dtype=jnp.float32)
        .swapaxes(1, 2)
        .reshape(n_atoms * 3, n_atoms * 3)
    )

    key = jax.random.key(0)
    model_kernel = Nequix(key=key, kernel=True, **model_kwargs)
    loss_kernel, avg_step_kernel, mem_kernel = train(
        model_kernel, graph, ref_hessian, n_epochs=n_epochs, label="HVP (kernel)"
    )

    key = jax.random.key(0)
    model_no_kernel = Nequix(key=key, kernel=False, **model_kwargs)
    n_params = sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(model_no_kernel, eqx.is_array)))
    loss_no_kernel, avg_step_no_kernel, mem_no_kernel = train(
        model_no_kernel, graph, ref_hessian, n_epochs=n_epochs, label="HVP (no kernel)"
    )

    print(f"With Kernel: {avg_step_kernel * 1000:.1f}ms/step, {mem_kernel:.2f}GB, final loss={loss_kernel[-1]:.2e}")
    print(f"No Kernel: {avg_step_no_kernel * 1000:.1f}ms/step, {mem_no_kernel:.2f}GB, final loss={loss_no_kernel[-1]:.2e}")
    speedup = avg_step_no_kernel / avg_step_kernel if avg_step_kernel > 0 else 0
    mem_ratio = mem_no_kernel / mem_kernel if mem_kernel > 0 else 0
    print(f"Speedup: {speedup:.1f}x, Memory: {mem_ratio:.1f}x")

    steps = np.arange(len(loss_no_kernel))
    fig, ax = plt.subplots(figsize=(5, 4))
    ax.plot(
        steps,
        loss_no_kernel,
        "b-",
        lw=2,
        label=f"No Kernel ({avg_step_no_kernel * 1000:.1f}ms/step, {mem_no_kernel:.2f}GB)",
    )
    ax.plot(
        steps,
        loss_kernel,
        "r--",
        lw=2,
        label=f"Kernel ({avg_step_kernel * 1000:.1f}ms/step, {mem_kernel:.2f}GB)",
    )
    ax.set(
        xlabel="Step",
        ylabel=r"Hessian MAE [meV/Å$^2$/atom]",
        title=f"Nequix {n_params // 1000}K Hessian Training (mp-149)",
    )
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("loss_comparison.png", dpi=300, bbox_inches="tight")
    plt.close()


if __name__ == "__main__":
    pft()
```
-
#181 adds JVP support; we verified that the original JVP script works. This is not yet released to PyPI in case further changes are required, but it is merged to main. You'll need to reinstall both oeq and openequivariance_extjax. Happy hunting, and let us know what the final speedup is (JVP loss curve below):
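If you want a quick, model-independent check that forward-over-reverse derivatives behave before plugging the kernels in, a toy comparison like the one below works (pure JAX, no OEQ calls; `energy` is just a stand-in scalar function):

```python
import jax
import jax.numpy as jnp


def energy(x):
    # stand-in "energy": any smooth scalar function of the coordinates
    return jnp.sum(jnp.sin(x) ** 2) + 0.5 * jnp.sum(x ** 4)


def hvp(f, x, v):
    # Hessian-vector product as a JVP of the gradient (forward-over-reverse)
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


x = jnp.linspace(-1.0, 1.0, 12)
v = jnp.zeros_like(x).at[3].set(1.0)  # picks out one Hessian column

column_via_hvp = hvp(energy, x, v)
column_dense = jax.hessian(energy)(x)[:, 3]  # dense Hessian, for comparison only

print(jnp.max(jnp.abs(column_via_hvp - column_dense)))  # should be at float32 noise level
```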
-
Working on the PR for Nequix. Here are the preliminary speedups we are seeing.
These are on a single A100 GPU. MPtrj uses a batch size of 64 with an energy/force/stress loss. PFT uses a batch size of 16 (although with much bigger supercell inputs) with an energy/force/stress/HVP loss. For multi-GPU training, we use … Would it be feasible to implement this? We'd be happy to provide a reproducer if helpful.
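For reference, the energy/force part of such a combined loss is compact in JAX. The sketch below is generic (stress and HVP terms omitted; `model_energy` and its parameters are placeholders, not the actual Nequix training code):

```python
import jax
import jax.numpy as jnp


def model_energy(params, positions):
    # placeholder for the actual model; returns a scalar potential energy
    return jnp.sum(params["k"] * positions ** 2)


def ef_loss(params, positions, e_ref, f_ref, w_e=1.0, w_f=10.0):
    # energy and forces from a single value_and_grad call; forces = -dE/dR
    energy_pred, grad_e = jax.value_and_grad(model_energy, argnums=1)(params, positions)
    forces_pred = -grad_e
    return w_e * (energy_pred - e_ref) ** 2 + w_f * jnp.mean((forces_pred - f_ref) ** 2)


params = {"k": jnp.float32(0.5)}
positions = jnp.ones((8, 3))
loss, grads = jax.value_and_grad(ef_loss)(params, positions, e_ref=10.0, f_ref=jnp.zeros((8, 3)))
print(loss, grads)
```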
-
Inference speedups are even greater, especially combined with jax-md (the MD overhead is more apparent once the kernels are in). For this ~2000 atom system of water molecules + NaCl I am getting almost an 18x speedup, from 0.73 steps/second to 13.0 steps/second! It can be run with nvt.py. Using the ASE calculator, we get a similar speedup to the torch version, about 8-10x.
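For anyone who wants to set up the same kind of run, the jax-md side of an NVT loop is short. Below is a rough sketch with a Lennard-Jones stand-in instead of the ML potential; `box`, `dt`, and `kT` are illustrative values, not the settings used for the water + NaCl system:

```python
import jax
import jax.numpy as jnp
from jax_md import space, energy, simulate

box, dt, kT = 20.0, 1e-3, 1.0

displacement_fn, shift_fn = space.periodic(box)
energy_fn = energy.lennard_jones_pair(displacement_fn)  # swap in the ML potential here

# start from a simple cubic lattice so the stand-in potential doesn't blow up
n_per_side = 8
grid = jnp.arange(n_per_side) * (box / n_per_side)
positions = jnp.stack(jnp.meshgrid(grid, grid, grid, indexing="ij"), axis=-1).reshape(-1, 3)

init_fn, apply_fn = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt=dt, kT=kT)
state = init_fn(jax.random.PRNGKey(0), positions)

@jax.jit
def run_chunk(state):
    # 100 MD steps per host-side call to amortize dispatch overhead
    return jax.lax.fori_loop(0, 100, lambda i, s: apply_fn(s), state)

for _ in range(10):
    state = run_chunk(state)
print(energy_fn(state.position))
```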
-
Released!
-
The README mentions "future frontend support outside of torch"; are there any plans for integration with JAX, e.g. using the foreign function interface?