From 6aac91a430ed653adcbd996ba20a88d262d89706 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 7 May 2024 14:48:10 -0700 Subject: [PATCH] Bump minimum jaxlib vesrion to 0.4.27 xla_extension_version is 261 and mlir_api_version is 56 PiperOrigin-RevId: 631556323 --- .../experimental/call_torch/call_torch_xla.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py index 97f3b9f..43c4626 100644 --- a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py +++ b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py @@ -201,19 +201,11 @@ def refine_polymorphic_shapes( Returns: The refined module. """ - if xc.mlir_api_version >= 53: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module), - enable_shape_assertions=validate_static_shapes, - validate_static_shapes=validate_static_shapes, - ) - elif xc.mlir_api_version >= 50: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module) - ) - else: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") - + refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access + mlir.module_to_bytecode(module), + enable_shape_assertions=validate_static_shapes, + validate_static_shapes=validate_static_shapes, + ) context = mlir.make_ir_context() with context: return ir.Module.parse(refined_module_str)