From af1b112d17c0ea41c9308a270009b9d64b948410 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 17 Nov 2023 05:49:22 -0800 Subject: [PATCH] Increase minimum jaxlib version to 0.4.19. 0.4.19 has xla_extension version 207 and mlir_api_version 54. PiperOrigin-RevId: 583358703 --- .../experimental/call_torch/call_torch_xla.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py index e6caa85..edf5c2f 100644 --- a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py +++ b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py @@ -201,18 +201,11 @@ def refine_polymorphic_shapes( Returns: The refined module. """ - if xc.mlir_api_version >= 53: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module), - enable_shape_assertions=validate_static_shapes, - validate_static_shapes=validate_static_shapes, - ) - elif xc.mlir_api_version >= 50: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module) - ) - else: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") + refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access + mlir.module_to_bytecode(module), + enable_shape_assertions=validate_static_shapes, + validate_static_shapes=validate_static_shapes, + ) context = mlir.make_ir_context() with context: