From bd55ed3cfc469d894f3202411661d7dba3233b0c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 9 Apr 2025 13:08:46 -0700 Subject: [PATCH] [JAX] Deprecate jax.lib.xla_client.{heap_profile,mlir_api_version}. PiperOrigin-RevId: 745704799 --- .../experimental/call_torch/call_torch_xla.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py index 7c321c1..1ac7d8b 100644 --- a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py +++ b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py @@ -204,19 +204,11 @@ def refine_polymorphic_shapes( Returns: The refined module. """ - if xc.mlir_api_version >= 53: - refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( - mlir.module_to_bytecode(module), - enable_shape_assertions=validate_static_shapes, - validate_static_shapes=validate_static_shapes, - ) - elif xc.mlir_api_version >= 50: - refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( - mlir.module_to_bytecode(module) - ) - else: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") - + refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( + mlir.module_to_bytecode(module), + enable_shape_assertions=validate_static_shapes, + validate_static_shapes=validate_static_shapes, + ) context = mlir.make_ir_context() with context: return ir.Module.parse(refined_module_str)