diff --git a/setup.py b/setup.py index dba4e6a33b..bc4d645baa 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "jaxlib>=0.1.69", "flax>=0.3.0, <0.4", "orjson~=3.4", - "optax>=0.0.2, <0.0.10", + "optax>=0.0.2, <0.1.2", ] setup(