diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index abc636e..53d762d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.9", "3.10"] steps: - name: Checkout repository diff --git a/deer/deer_iter.py b/deer/deer_iter.py index fd0048c..7ae2502 100644 --- a/deer/deer_iter.py +++ b/deer/deer_iter.py @@ -340,9 +340,6 @@ def deer_iteration_jvp( inv_lin2 = partial(inv_lin, gts) _, grad_yt = jax.jvp(inv_lin2, (rhs0, inv_lin_params), (grad_func, grad_inv_lin_params)) - # Create the tangent for is_converged - is_converged_tangent = jnp.zeros_like(is_converged, dtype=jax.dtypes.float0) - result = Result(yt, success=is_converged) - grad_result = Result(grad_yt, success=is_converged_tangent) + grad_result = Result(grad_yt) return result, grad_result diff --git a/deer/froot.py b/deer/froot.py index e91f79e..5e41d0d 100644 --- a/deer/froot.py +++ b/deer/froot.py @@ -192,7 +192,6 @@ def newton_iter_jvp( _, grad_func = jax.jvp(func_partial_y, (params,), (grad_params,)) grad_y = jnp.linalg.solve(jac, -grad_func) # (ny,) - is_converged_tangent = jnp.zeros_like(is_converged, dtype=jax.dtypes.float0) result = Result(yt, is_converged) - grad_result = Result(grad_y, success=is_converged_tangent) + grad_result = Result(grad_y) return result, grad_result diff --git a/deer/utils.py b/deer/utils.py index cdd2489..5fb9efc 100644 --- a/deer/utils.py +++ b/deer/utils.py @@ -20,9 +20,6 @@ def __init__(self, value: jnp.ndarray, success: Union[bool, None, jnp.ndarray] = success = jnp.full_like(value, True, dtype=jnp.bool) elif isinstance(success, bool): success = jnp.full_like(value, success, dtype=jnp.bool) - elif hasattr(success, "dtype") and success.dtype == jax.dtypes.float0: - # The Bool outputs of `jax.custom_jvp` requires tangents with `float0` type since jax 0.4.34. - success = jnp.full_like(value, success, dtype=jax.dtypes.float0) elif isinstance(success, jnp.ndarray): assert success.dtype == jnp.bool success = jnp.broadcast_to(success, value.shape) diff --git a/setup.py b/setup.py index 71204a4..fa73ecf 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,10 @@ author_email='muhammad@machine-discovery.com', license='BSD-3', packages=find_packages(), - python_requires=">=3.10", + python_requires=">=3.9", install_requires=[ - "jaxlib>=0.4.28", - "jax>=0.4.34", + "jaxlib<=0.4.30", + "jax[cuda12]<=0.4.30", "numpy>=1.24.0", "scipy>=1.10.1", "equinox>=0.10.6",