diff --git a/vmoe/partitioning.py b/vmoe/partitioning.py index 9506b32..1b9d5ba 100644 --- a/vmoe/partitioning.py +++ b/vmoe/partitioning.py @@ -84,6 +84,15 @@ UnparsedPartitionSpec = Union[str, Tuple[Union[str, Tuple[str, ...]], ...]] +def get_array_sharding_or_default(arr: jax.Array) -> sharding.Sharding: + if hasattr(arr, 'sharding'): + return arr.sharding + else: + op_sharding = jax.xla.xc.OpSharding() + op_sharding.type = jax.xla.xc.OpSharding.Type.REPLICATED + return sharding.OpShardingSharding(jax.devices(), op_sharding) + + def process_has_contiguous_device_slice(devices: np.ndarray, process_index: int) -> bool: """Checks if the devices of a process form a contiguous slice in the mesh."""