From 4d9970632e8ba3e6afab1df4477c50b96e36bbc6 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Fri, 10 Jan 2025 17:12:54 +0000 Subject: [PATCH 1/2] Revert "Merge branch 'mfk' of github.com:machine-discovery/deer into mfk" This reverts commit bdeb567519bcc9c1d23c7298656ddfe1dfbea731, reversing changes made to f339be2c16788a932b24c0b8aed543360fcf2ee8. --- .github/workflows/ci.yml | 2 +- deer/deer_iter.py | 5 +---- deer/froot.py | 3 +-- deer/utils.py | 3 --- setup.py | 4 ++-- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index abc636e..53d762d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.9", "3.10"] steps: - name: Checkout repository diff --git a/deer/deer_iter.py b/deer/deer_iter.py index fd0048c..7ae2502 100644 --- a/deer/deer_iter.py +++ b/deer/deer_iter.py @@ -340,9 +340,6 @@ def deer_iteration_jvp( inv_lin2 = partial(inv_lin, gts) _, grad_yt = jax.jvp(inv_lin2, (rhs0, inv_lin_params), (grad_func, grad_inv_lin_params)) - # Create the tangent for is_converged - is_converged_tangent = jnp.zeros_like(is_converged, dtype=jax.dtypes.float0) - result = Result(yt, success=is_converged) - grad_result = Result(grad_yt, success=is_converged_tangent) + grad_result = Result(grad_yt) return result, grad_result diff --git a/deer/froot.py b/deer/froot.py index e91f79e..5e41d0d 100644 --- a/deer/froot.py +++ b/deer/froot.py @@ -192,7 +192,6 @@ def newton_iter_jvp( _, grad_func = jax.jvp(func_partial_y, (params,), (grad_params,)) grad_y = jnp.linalg.solve(jac, -grad_func) # (ny,) - is_converged_tangent = jnp.zeros_like(is_converged, dtype=jax.dtypes.float0) result = Result(yt, is_converged) - grad_result = Result(grad_y, success=is_converged_tangent) + grad_result = Result(grad_y) return result, grad_result diff --git a/deer/utils.py b/deer/utils.py index cdd2489..5fb9efc 100644 --- a/deer/utils.py +++ b/deer/utils.py @@ -20,9 +20,6 @@ def __init__(self, value: jnp.ndarray, success: Union[bool, None, jnp.ndarray] = success = jnp.full_like(value, True, dtype=jnp.bool) elif isinstance(success, bool): success = jnp.full_like(value, success, dtype=jnp.bool) - elif hasattr(success, "dtype") and success.dtype == jax.dtypes.float0: - # The Bool outputs of `jax.custom_jvp` requires tangents with `float0` type since jax 0.4.34. - success = jnp.full_like(value, success, dtype=jax.dtypes.float0) elif isinstance(success, jnp.ndarray): assert success.dtype == jnp.bool success = jnp.broadcast_to(success, value.shape) diff --git a/setup.py b/setup.py index 71204a4..bab4693 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,10 @@ author_email='muhammad@machine-discovery.com', license='BSD-3', packages=find_packages(), - python_requires=">=3.10", + python_requires=">=3.9", install_requires=[ "jaxlib>=0.4.28", - "jax>=0.4.34", + "jax>=0.4.28", "numpy>=1.24.0", "scipy>=1.10.1", "equinox>=0.10.6", From b33e8e5885d24a79eba4f7b222b8dde18f23b598 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Fri, 10 Jan 2025 17:19:56 +0000 Subject: [PATCH 2/2] limit jax version to 0.4.30 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index bab4693..fa73ecf 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,8 @@ packages=find_packages(), python_requires=">=3.9", install_requires=[ - "jaxlib>=0.4.28", - "jax>=0.4.28", + "jaxlib<=0.4.30", + "jax[cuda12]<=0.4.30", "numpy>=1.24.0", "scipy>=1.10.1", "equinox>=0.10.6",