17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,22 @@
## Latest Changes

### v0.5.4 (2025-02-01)
Improvements to the JAX frontend.

**Added**:
- Jacobian Vector Products (JVP) for both `TensorProduct` and `TensorProductConv` via custom primitives, in addition to VJP (see the sketch after this list).
- Arbitrary higher-order derivatives in JAX.
- JAX JIT support; in particular, support for
Phonon Fine Tuning in [Nequix](https://github.com/atomicarchitects/nequix).
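
A minimal sketch of how these features combine, using a plain JAX function as a stand-in for an OpenEquivariance tensor product (the `tensor_product` callable and the shapes below are illustrative placeholders, not the library's API):

```python
import jax
import jax.numpy as jnp

# Illustrative stand-in for a TensorProduct-style operation; the real
# OpenEquivariance call signature is not reproduced here.
def tensor_product(x, y):
    return jnp.sin(x) * y

x = jnp.ones(8)
y = jnp.linspace(0.0, 1.0, 8)
v = jnp.ones(8)

# Forward mode: Jacobian-vector product (JVP), now available alongside VJP.
primal_out, tangent_out = jax.jvp(lambda a: tensor_product(a, y), (x,), (v,))

# Reverse mode: vector-Jacobian product (VJP).
out, vjp_fn = jax.vjp(lambda a: tensor_product(a, y), x)
(cotangent,) = vjp_fn(jnp.ones_like(out))

# Higher-order derivatives and JIT compose, e.g. second derivatives of an
# energy-like scalar, as needed for phonon workflows.
energy = lambda a: jnp.sum(tensor_product(a, y))
hessian = jax.jit(jax.hessian(energy))(x)
```

The changelog entries above describe the same capabilities applied to the library's actual `TensorProduct` and `TensorProductConv` primitives.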

**Fixed**:
- Zeroed all output buffers in the backward and double-backward implementations of convolution before calling kernels (see the sketch below).
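
For context on the fix above, a small NumPy sketch of the failure mode it guards against: a backward pass that accumulates into a pre-allocated buffer produces stale results unless that buffer is zeroed first. The function here is purely illustrative; the actual fix lives in the convolution backward kernels.

```python
import numpy as np

def accumulate_backward(out_buffer, contributions):
    # Mimics a backward kernel that adds per-edge gradient contributions
    # into a pre-allocated output buffer.
    for c in contributions:
        out_buffer += c
    return out_buffer

contributions = [np.ones(4), 2.0 * np.ones(4)]

# Correct: zero the buffer before launching the accumulating kernel.
grad_ok = accumulate_backward(np.zeros(4), contributions)   # [3. 3. 3. 3.]

# Bug pattern: an uninitialized buffer carries whatever was already in
# memory, so the accumulated gradient is silently offset by stale values.
grad_bad = accumulate_backward(np.empty(4), contributions)  # 3.0 + garbage
```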

### v0.5.1-0.5.3 (2025-02-01)
Minor bugfixes related to packaging and JAX.

### v0.5.0 (2025-12-25)
JAX support is now available in
OpenEquivariance for BOTH NVIDIA and
@@ -320,7 +320,7 @@ def __init__(self, num_blocks, num_threads, warp_size, smem):
         self.num_blocks = num_blocks
         self.num_threads = num_threads
         self.warp_size = warp_size
-        self.smem = smem
+        self.smem = int(smem)
 
 
 class ComputationSchedule:
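
The `int(smem)` cast above is a common defensive conversion: when a shared-memory size is computed with NumPy or floating-point arithmetic, the result is a NumPy scalar or float rather than the plain Python `int` a kernel launch configuration expects. A hedged sketch of the idea (the arithmetic below is illustrative, not taken from the scheduler):

```python
import numpy as np

# Size arithmetic often yields NumPy scalars or floats rather than ints.
irrep_dim = np.int64(32)
dtype_bytes = 4
smem = irrep_dim * dtype_bytes * 1.5   # np.float64(192.0), not a Python int

assert not isinstance(smem, int)
smem = int(smem)                       # same normalization as the constructor
assert isinstance(smem, int) and smem == 192
```
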
11 changes: 0 additions & 11 deletions openequivariance/openequivariance/core/utils.py
@@ -7,7 +7,6 @@
 
 import json
 import tempfile
-import hashlib
 
 from enum import IntEnum
 
@@ -200,13 +199,3 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
         time_millis[i] = kernel_time
 
     return time_millis
-
-
-def hash_attributes(attrs):
-    m = hashlib.sha256()
-
-    for key in sorted(attrs.keys()):
-        m.update(attrs[key].__repr__().encode("utf-8"))
-
-    hash = int(m.hexdigest()[:16], 16) >> 1
-    attrs["hash"] = hash