From 64c5b1c76ebf47d638cfcdda1470a2c9bdcbbe3d Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 14 Jul 2025 07:01:02 -0700 Subject: [PATCH] Accelerate deprecation for attributes in jax.lib.xla_client and redirect dependencies to `xla/python/xla_client`. Attributes: - `Client` - `CompileOptions` - `Frame` - `HloSharding` - `OpSharding` - `Traceback` PiperOrigin-RevId: 782888911 --- jaxonnxruntime/experimental/export/exportable_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxonnxruntime/experimental/export/exportable_utils.py b/jaxonnxruntime/experimental/export/exportable_utils.py index a5d8a15..77176e6 100644 --- a/jaxonnxruntime/experimental/export/exportable_utils.py +++ b/jaxonnxruntime/experimental/export/exportable_utils.py @@ -14,20 +14,19 @@ """jax.export.Exported utils.""" -import io import os +from typing import Any import jax from jax import export as jax_export from jax import numpy as jnp -from jax.lib import xla_client from jaxlib.mlir import ir from mlir.dialects import stablehlo import tensorflow as tf import torch MLIRModule = ir.Module -HloSharding = xla_client.HloSharding | None +HloSharding = Any | None Sharding = jax.sharding.Sharding | None