From 0511e5dbf30d4ca53e69b3df3c3641344c5cb5bb Mon Sep 17 00:00:00 2001 From: liyingli Date: Wed, 4 Feb 2026 07:23:27 +0000 Subject: [PATCH 01/24] update maxtext to version 022dc02 --- third_party/maxtext | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/maxtext b/third_party/maxtext index 8def32a8a..022dc02eb 160000 --- a/third_party/maxtext +++ b/third_party/maxtext @@ -1 +1 @@ -Subproject commit 8def32a8a5b96fc6267636a8e58abfe4c178e161 +Subproject commit 022dc02eb89057350d2e365f23c8f1f0edb4732d From 2e58d6781015e8451e67c9be7bf998d9fed07464 Mon Sep 17 00:00:00 2001 From: liyingli Date: Wed, 4 Feb 2026 13:23:56 +0000 Subject: [PATCH 02/24] update maxtext part 1 --- primus/backends/maxtext/checkpointing.py | 186 +++++++++++- primus/backends/maxtext/configs/__init__.py | 0 primus/backends/maxtext/configs/types.py | 12 + .../input_pipeline/_hf_data_processing.py | 62 ++-- .../input_pipeline/custom_packed_batch.py | 215 -------------- .../backends/maxtext/layers/attention_op.py | 60 ++-- primus/backends/maxtext/layers/attentions.py | 49 ---- primus/backends/maxtext/layers/gemma.py | 119 ++++++++ primus/backends/maxtext/layers/gemma2.py | 195 +++++++++++++ primus/backends/maxtext/layers/llama2.py | 115 ++++++++ primus/backends/maxtext/layers/mistral.py | 103 +++++++ primus/backends/maxtext/layers/mixtral.py | 108 +++++++ primus/backends/maxtext/layers/moe.py | 12 +- primus/backends/maxtext/max_utils.py | 127 ++++++++ primus/backends/maxtext/train.py | 272 +++++++++--------- primus/backends/maxtext/train_utils.py | 41 ++- .../configs/modules/maxtext/trainer_base.yaml | 106 ++++++- 17 files changed, 1281 insertions(+), 501 deletions(-) create mode 100644 primus/backends/maxtext/configs/__init__.py create mode 100644 primus/backends/maxtext/configs/types.py delete mode 100644 primus/backends/maxtext/input_pipeline/custom_packed_batch.py delete mode 100644 primus/backends/maxtext/layers/attentions.py create mode 100644 primus/backends/maxtext/layers/gemma.py create mode 100644 primus/backends/maxtext/layers/gemma2.py create mode 100644 primus/backends/maxtext/layers/llama2.py create mode 100644 primus/backends/maxtext/layers/mistral.py create mode 100644 primus/backends/maxtext/layers/mixtral.py diff --git a/primus/backends/maxtext/checkpointing.py b/primus/backends/maxtext/checkpointing.py index a04dfc4b1..1034f32e5 100644 --- a/primus/backends/maxtext/checkpointing.py +++ b/primus/backends/maxtext/checkpointing.py @@ -7,9 +7,178 @@ from typing import Any -import orbax.checkpoint as ocp +import jax + from etils import epath +from flax.training import train_state +import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager +import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager + from MaxText import max_logging +from MaxText.checkpointing import _replica_devices, _restore_grain_iterator, load_params_from_path, _load_full_state_from_path +from MaxText.multihost_dataloading import MultiHostDataLoadIterator +from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator + +Composite = ocp.args.Composite +EmergencyCheckpointManager = emergency_checkpoint_manager.CheckpointManager +EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager + + +def load_state_if_possible( + checkpoint_manager: ocp.CheckpointManager | None, + data_iterator: MultiHostDataLoadIterator | list[MultiHostDataLoadIterator] | None, + load_parameters_from_path: str, + load_full_state_from_path: str, + checkpoint_storage_concurrent_gb: int, + abstract_unboxed_pre_state: train_state.TrainState, + enable_single_replica_ckpt_restoring: bool | None = False, + dataset_type: str | None = "tfds", + step: int = -1, # -1 means latest + use_ocdbt=True, + use_zarr3=True, + enable_orbax_v1=False, + checkpoint_conversion_fn=None, + source_checkpoint_layout="orbax", + expansion_factor_real_data: int = -1, + ): + """Loads TrainState as possible from the inputs. + + Args: + checkpoint_manager: if the checkpoint_manager has a valid checkpoint, return + that TrainState. This enables a full reload of a run in progress. + load_parameters_from_path: if there is no checkpoint in the checkpoint + manager, load parameters from a parameter only checkpoint at this path. + load_full_state_from_path: if there is no checkpoint in the checkpoint + manager, load full state from a full state checkpoint at this path. + abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax + matches type against. + enable_single_replica_ckpt_restoring: bool flag for restoring checkpoitn + with SingleReplicaArrayHandler + checkpoint_storage_concurrent_gb: concurrent GB for checkpoint byte I/O. + enable_orbax_v1: bool flag for enabling Orbax v1. + checkpoint_conversion_fn: function for converting checkpoint to Orbax v1. + source_checkpoint_layout: Optional checkpoint context to use for loading, + provided in string format with the default being "orbax". + + Returns: + A tuple of (train_state, train_state_params) where full_train_state captures + a full reload and train_state_params just the params for a partial reload. + At most one will be non-None. Both can be None if neither checkpoint is + set. + """ + + if checkpoint_manager is not None: + max_logging.log("checkpoint manager exists so trying to load this run's existing checkpoint") + + step = checkpoint_manager.latest_step() if step < 0 else step + if step is not None: + max_logging.log(f"restoring from this run's directory step {step}") + + def map_to_pspec(data): + if not enable_single_replica_ckpt_restoring: + return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) + pspec = data.sharding.spec + mesh = data.sharding.mesh + replica_axis_index = 0 + replica_devices = _replica_devices(mesh.devices, replica_axis_index) + replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + + return ocp.type_handlers.SingleReplicaArrayRestoreArgs( + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + global_shape=data.shape, + dtype=data.dtype, + ) + + # Cache the original ArrayHandler before potentially overriding it. + # This is the same handler used when enable_single_replica_ckpt_restoring=False. + original_array_handler = ocp.type_handlers.get_type_handler(jax.Array) + + # Register SingleReplicaArrayHandler globally for restore (if enabled) + if enable_single_replica_ckpt_restoring: + single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit + ) + ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True) + + restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + + def _restore_original_array_handler(): + """Restore the original ArrayHandler after SingleReplicaArrayHandler restore. + + This is critical because SingleReplicaArrayHandler is designed for restore only. + Using it for saves will cause missing array_metadatas files and checkpoint failures. + We restore the EXACT handler that was in place before, not a new instance. + """ + if enable_single_replica_ckpt_restoring: + max_logging.log("Restoring original ArrayHandler after SingleReplicaArrayHandler restore...") + # Re-register the original handler that was cached before the override + ocp.type_handlers.register_type_handler(jax.Array, original_array_handler, override=True) + max_logging.log("Original ArrayHandler restored successfully.") + + match (checkpoint_manager, dataset_type, data_iterator): + # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager + # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and + # 'data_iterator' can be any value and aren't used in this pattern. + case (checkpoint_manager, _, _) if isinstance( + checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) + ): + result = ( + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, + None, + ) + _restore_original_array_handler() + return result + # Case 2: Matches if dataset type is "grain" and the data iterator is not a + # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator + case ( + checkpoint_manager, + dataset_type, + data_iterator, + ) if ( + dataset_type == "grain" + and data_iterator + and not isinstance(data_iterator, PlaceHolderDataIterator) + and (checkpoint_manager.directory / str(step) / "iter").exists() + ): + result = _restore_grain_iterator( + checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data + ) + _restore_original_array_handler() + return result + # Case 3: Default/Fallback case. + # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. + case _: + result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) + _restore_original_array_handler() + return result + + if load_parameters_from_path != "": + restored_params = load_params_from_path( + load_parameters_from_path, + abstract_unboxed_pre_state.params, + checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + return None, restored_params + elif load_full_state_from_path != "": + max_logging.log(f"Loading full state from path: {load_full_state_from_path}") + restored_state = _load_full_state_from_path( + path=load_full_state_from_path, + abstract_unboxed_pre_state=abstract_unboxed_pre_state, + enable_orbax_v1=enable_orbax_v1, + checkpoint_conversion_fn=checkpoint_conversion_fn, + source_checkpoint_layout=source_checkpoint_layout, + ) + return {"items": restored_state}, None + else: + max_logging.log("No existing checkpoints found, not restoring checkpoint.") + return None, None def create_orbax_checkpoint_manager( @@ -30,17 +199,18 @@ def create_orbax_checkpoint_manager( max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") + # Base configuration for all dataset types + item_names = ("items",) + # we need to use ocdbt and zarr3 to control max file size in the checkpoint + item_handlers = {"items": ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} + if dataset_type == "grain": - item_names = ("items", "iter") - else: - item_names = ("items",) + item_names += ("iter",) + item_handlers["iter"] = ocp.GrainCheckpointHandler() # local storage checkpoint needs parent directory created p = epath.Path(checkpoint_dir) p.mkdir(exist_ok=True, parents=True) - # we need to use ocdbt and zarr3 to control max file size in the checkpoint - # omitting `iter` uses default handler for `iter` - item_handlers = {"items": ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} manager = ocp.CheckpointManager( p, item_names=item_names, @@ -49,7 +219,7 @@ def create_orbax_checkpoint_manager( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, - max_to_keep=max_to_keep, + max_to_keep = max_to_keep, ), logger=orbax_logger, ) diff --git a/primus/backends/maxtext/configs/__init__.py b/primus/backends/maxtext/configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus/backends/maxtext/configs/types.py b/primus/backends/maxtext/configs/types.py new file mode 100644 index 000000000..c79555ed7 --- /dev/null +++ b/primus/backends/maxtext/configs/types.py @@ -0,0 +1,12 @@ +from pydantic.fields import Field + +from MaxText.configs.types import MoEGeneral, DevelopmentAndDebugging + +class PrimusMoEGeneral(MoEGeneral): + expert_balance: bool = Field(False, description="Whether to use expert balancing.") + + +class PrimusDevelopmentAndDebugging(DevelopmentAndDebugging): + jax_distributed_heartbeat_timeout_seconds: int = Field( + 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." + ) \ No newline at end of file diff --git a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py index 641ac3845..d776d635e 100644 --- a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py +++ b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py @@ -12,11 +12,9 @@ import numpy as np import transformers from MaxText import multihost_dataloading -from MaxText.input_pipeline import _input_pipeline_utils +from MaxText.input_pipeline import _input_pipeline_utils, instruction_data_processing from MaxText.input_pipeline._hf_data_processing import vision_sft_preprocessing_pipeline -from .custom_packed_batch import CustomPackAndBatchOperation - def preprocessing_pipeline( dataloading_host_index, @@ -31,18 +29,19 @@ def preprocessing_pipeline( max_target_length, shuffle, data_shuffle_seed, + chat_template_path="", add_bos=True, add_eos=True, packing=True, shift=True, num_threads=1, - drop_remainder=False, - generate_padding_example=False, + drop_remainder=True, + generate_padding_batch=False, use_dpo=None, use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 - max_segments=1, # max segments per sequence + max_segments_per_seq=1, # max segments per sequence ): """pipeline for preprocessing HF dataset""" assert ( @@ -63,10 +62,16 @@ def preprocessing_pipeline( if use_sft: dataset = dataset.select_columns(data_column_names) - supported_columns = [["prompt", "completion"], ["messages"]] + supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] assert any( set(data_column_names) == set(supported) for supported in supported_columns ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}" + + # convert instruction dataset to conversational format + dataset, data_column_names = instruction_data_processing.convert_to_conversational_format( + dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path + ) + assert _input_pipeline_utils.is_conversational( dataset.features, data_column_names ), "Dataset is not in conversational format." @@ -74,11 +79,7 @@ def preprocessing_pipeline( if len(data_column_names) > 1: combined_column_name = "messages" dataset_features = datasets.Features( - { - combined_column_name: [ - {"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")} - ] - } + {combined_column_name: [{"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")}]} ) dataset = dataset.map( _input_pipeline_utils.combine_columns, @@ -119,7 +120,6 @@ def preprocessing_pipeline( dataloading_host_index, dataloading_host_count, num_threads, - generate_padding_example, max_target_length, data_column_names, ) @@ -147,25 +147,21 @@ def lists2array(x): data_column_names = ("inputs", "targets") if packing and not use_dpo: - # monkey patch the splitter to handle TE's maximum segment limitation length_struct = {col: max_target_length for col in data_column_names} - pack_and_batch = CustomPackAndBatchOperation( - batch_size=global_batch_size // jax.process_count(), - length_struct=length_struct, - max_segments=max_segments, + operations.append( + grain.experimental.PackAndBatchOperation( + batch_size=global_batch_size // jax.process_count(), + length_struct=length_struct, + max_sequences_per_bin=max_segments_per_seq, + ) ) - operations.append(pack_and_batch) operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(_input_pipeline_utils.PadToMaxLength(max_target_length, pad_id)) - operations.append( - grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder) - ) + operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder)) if shift and not use_dpo: - operations.append( - _input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1) - ) + operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1)) # Since HuggingFace IterableDataset does not support access through index # Indexes generated by dummy_index_sampler is not used. @@ -189,12 +185,11 @@ def lists2array(x): read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128), ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch) # Return multi-host jax.Array prep iterator return multihost_gen - def make_hf_train_iterator( config: ml_collections.ConfigDict, global_mesh, @@ -237,11 +232,12 @@ def make_hf_train_iterator( add_bos=config.add_bos, add_eos=config.add_eos, packing=config.packing, - generate_padding_example=False, + generate_padding_batch=config.generate_padding_batch_train, use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, - max_segments=config.max_segments, + chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return train_iter @@ -261,7 +257,6 @@ def make_hf_eval_iterator( token=config.hf_access_token, ) - eval_generate_padding_example = config.eval_steps > 0 if config.use_sft and config.use_multimodal: eval_iter = vision_sft_preprocessing_pipeline( dataset=eval_ds, @@ -290,10 +285,11 @@ def make_hf_eval_iterator( add_bos=config.add_bos, add_eos=config.add_eos, packing=config.packing, - generate_padding_example=eval_generate_padding_example, + generate_padding_batch=config.generate_padding_batch_eval, use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, - max_segments=config.max_segments, + chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return eval_iter diff --git a/primus/backends/maxtext/input_pipeline/custom_packed_batch.py b/primus/backends/maxtext/input_pipeline/custom_packed_batch.py deleted file mode 100644 index 3f6b2e21e..000000000 --- a/primus/backends/maxtext/input_pipeline/custom_packed_batch.py +++ /dev/null @@ -1,215 +0,0 @@ -############################################################################### -# Copyright 2023–2025 Google LLC. All rights reserved. -# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### -""" -Forked from https://github.com/google/grain/blob/7841100258c90c77fcebdd668232aea9c0314fc2/grain/_src/python/experimental/example_packing/packing.py - -Customized packing based on MaxText's default -Modified to support max segments per sequence - -""" - -import dataclasses -from typing import Any, Generic, Iterator, TypeVar, Union, cast - -import numpy as np -from absl import logging -from grain._src.core import tree_lib -from grain._src.python import record - -_T = TypeVar("_T") - - -class _PackedBatch: - """Class to represent a batch of packed examples.""" - - def __init__( - self, - element_for_shapes: Any, # PyTree[np.ndarray] - batch_size: int, - length_struct: Any, # PyTree[int] - max_segments: int, - ): - self._batch_size = batch_size - self._length_struct = length_struct - self._max_segments = max_segments - - # Define the main buffers we will pack the data into. - def make_packed_buffer(length: int, input_arr: np.ndarray): - return np.zeros( - shape=(batch_size, length, *input_arr.shape[1:]), # (B, T, ...) - dtype=input_arr.dtype, - ) - - self._batch = tree_lib.map_structure(make_packed_buffer, length_struct, element_for_shapes) - - def make_packed_aux_info(length: int): - return np.zeros(shape=(batch_size, length), dtype=np.int32) - - self._segmentations = tree_lib.map_structure(make_packed_aux_info, length_struct) - self._positions = tree_lib.map_structure(make_packed_aux_info, length_struct) - - # Tracks the next empty position to insert an example for each row - # in the batch, for each feature in features_to_pack. - self._first_free_cell_per_row = tree_lib.map_structure( - lambda _: np.zeros(batch_size, dtype=np.int32), length_struct - ) - - # Tracks the number of examples already packed into row of the batch. Used - # to fill the segmentation values for each feature. - self._num_examples_per_row = [0 for _ in range(batch_size)] - - # For determinism, the metadata.index for the packed batch must match - # metadata.index of the _last_ included input example. - self._last_record_metadata = None - - def get_packed_batch(self) -> record.Record[tuple[_T, _T, _T]]: - assert self._last_record_metadata is not None - return record.Record( - metadata=cast(record.RecordMetadata, self._last_record_metadata), - data=(self._batch, self._segmentations, self._positions), - ) - - def _can_add_at_row( - self, - element: Any, # PyTree[np.ndarray] - ) -> int: - """Returns the index of the first row which fits element, or -1 if none.""" - element_feature_lengths = tree_lib.map_structure(len, element) - - # Check no feature exceeds max length - length_exceeded = tree_lib.map_structure( - lambda feature_length, max_length: feature_length > max_length, - element_feature_lengths, - self._length_struct, - ) - if any(tree_lib.flatten(length_exceeded)): - raise ValueError("Inputs to PackAndBatchOperation must be truncated to max length.") - - # For each row, check whether the total length after adding the current - # element would exceed max feature lengths. - def _feature_will_fit(feature_length, first_free_cell, max_length): - return feature_length + first_free_cell <= max_length - - is_row_free_struct = tree_lib.map_structure( - _feature_will_fit, element_feature_lengths, self._first_free_cell_per_row, self._length_struct - ) - - ## Pick first row (if exists) where element can be added. - for i in range(self._batch_size): - if self._num_examples_per_row[i] < self._max_segments: - row_is_free_per_feature = [free[i] for free in tree_lib.flatten(is_row_free_struct)] - if all(row_is_free_per_feature): - return i - return -1 - - def add_element_to_batch( - self, - element: Any, # PyTree[np.ndarray] - row: int, - ) -> None: - """Adds element to current batch at the specified row.""" - # Apply updates to each feature. - for per_feature_data in zip( - tree_lib.flatten(element), - tree_lib.flatten(self._batch), - tree_lib.flatten(self._segmentations), - tree_lib.flatten(self._positions), - tree_lib.flatten(self._first_free_cell_per_row), - ): - value, batch_value, segmentations, positions, first_free_cell_per_row = per_feature_data - # Update batch value, segmentations, and positions. - start = first_free_cell_per_row[row] - end = first_free_cell_per_row[row] + len(value) - batch_value[row][start:end] = value - segmentations[row][start:end] = self._num_examples_per_row[row] + 1 - positions[row][start:end] = np.arange(end - start) - # Update first_free_cell_per_row. - first_free_cell_per_row[row] += len(value) - - self._num_examples_per_row[row] += 1 - - def try_add_to_batch(self, element: record.Record) -> bool: - """Finds a row in the batch at which element can be added.""" - if (row_idx := self._can_add_at_row(element.data)) == -1: - return False - self.add_element_to_batch(element.data, row_idx) - self._last_record_metadata = element.metadata.remove_record_key() - return True - - -@dataclasses.dataclass -class CustomPackAndBatchOperation(Generic[_T]): - """PyGrain pack-and-batch operation - see module docstring. - - WARNING: This class is deprecated. Please use - lazy_dataset.FirstFitPackIterDataset instead. - - Attributes: - batch_size: int, the batch size. - length_struct: A pytree, with the same structure as `input_iterator` - elements, but where leaves are ints, representing the packed length of the - corresponding feature. - max_segments: int, max segments per sequence - - __call__() takes an input iterator, where elements are `Record`s containing: - - input_data: Pytrees of arrays. For more info about PyTrees, please refer to: - https://jax.readthedocs.io/en/latest/pytrees.html. Packed leaves should be - n-dimensional arrays, with sequence length as the leading dimension, i.e. - shape (T_in, ...), where T_in < T_packed. Note that leaves can and will - often have ragged length dimensions across different elements of the input - iterator. - - The output of __call__() will be an iterator over `Record`s containing a - 3-tuple of Pytrees. These are: - - data: The batched and packed data. This is a Pytree with parallel structure - to elements of `input_iterator`. Leaves have shape (B, T_packed, ...). - segmentations: Pytree with the same structure as `data`, and leaves of shape - (B, T). Represents which example each entry comes from. This may be used - for Transformer attention masks, for example. - positions: Pytree with the same structure as `data`, and leaves of shape - (B, T). Represents the position of each entry within their original - example. This may be used e.g. in Transformer absolute position - embeddings. - """ - - length_struct: Any # PyTree[int] - batch_size: int - max_segments: int - # We don't know input shapes and corresponding buffer shapes until __call__. - _cur_batch: Union[_PackedBatch, None] = None - - def __post_init__(self): - logging.error( - "PackAndBatchOperation is deprecated. Please use" " lazy_dataset.FirstFitPackIterDataset instead." - ) - - def __call__( - self, input_iterator: Iterator[record.Record[_T]] - ) -> Iterator[record.Record[tuple[_T, _T, _T]]]: - for element in input_iterator: - # Use `element` to set dtypes + trailing dimensions. - if self._cur_batch is None: # pytype: disable=attribute-error - self._cur_batch = _PackedBatch( - element.data, self.batch_size, self.length_struct, self.max_segments - ) - - # Try adding element to the current packed batch. - element_added_to_batch = self._cur_batch.try_add_to_batch(element) - - # When we have a full batch, yield the current packed data, - # and then start a new batch with this element. - if not element_added_to_batch: - yield self._cur_batch.get_packed_batch() # Main yield - self._cur_batch = _PackedBatch( - element.data, self.batch_size, self.length_struct, self.max_segments - ) - self._cur_batch.try_add_to_batch(element) - - # Final batch - yield self._cur_batch.get_packed_batch() diff --git a/primus/backends/maxtext/layers/attention_op.py b/primus/backends/maxtext/layers/attention_op.py index 5cdd9aee6..7e85c089c 100644 --- a/primus/backends/maxtext/layers/attention_op.py +++ b/primus/backends/maxtext/layers/attention_op.py @@ -6,7 +6,8 @@ ############################################################################### import jax.numpy as jnp -from MaxText.common_types import MODEL_MODE_TRAIN, Array, AttentionType +from MaxText.common_types import DEFAULT_MASK_VALUE, MODEL_MODE_TRAIN, Array, AttentionType +from MaxText.layers import nnx_wrappers from MaxText.layers.attention_op import AttentionOp @@ -23,47 +24,50 @@ def cudnn_flash_attention( model_mode: str = MODEL_MODE_TRAIN, ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA, SWA (only with causal masking) - 2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6 + 1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism + 2. Context Parallelism currently only supports causal masking and no packing """ # These imports are only meant to work in a GPU build. # pylint: disable=import-outside-toplevel - from transformer_engine.jax.flax.transformer import ( - DotProductAttention, # pytype: disable=import-error - ) + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + from transformer_engine.jax.attention import SequenceDescriptor # pytype: disable=import-error _, _, _, head_dim = query.shape # pylint: disable=unused-variable using_context_parallelism = self.mesh.shape["context"] > 1 - if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism: - raise AssertionError( - "Sliding window attention is not supported when context parallelism is enabled" - ) - + # Initialize default attention configuration sliding_window_size = None mask_type = "padding_causal" - qkv_layout = "BSHD_BSHD_BSHD" # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + qkv_layout = "BSHD_BSHD_BSHD" # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' max_segments_per_seq = 1 # max number of segments per sequence; for non-packed its 1 + # Handle local sliding window attention if configured if self.attention_type == AttentionType.LOCAL_SLIDING: sliding_window_size = [self.sliding_window_size, 0] + # Handle packing configurations if self.config.packing and self.config.dataset_type != "synthetic": + qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos( - segment_ids=decoder_segment_ids, segment_pos=None - ) - qkv_layout = "THD_THD_THD" # 'T3HD', 'THD_T2HD' or 'THD_THD_THD' - max_segments_per_seq = 32 - elif ( - using_context_parallelism or self.config.dataset_type == "synthetic" - ): # context parallelism currently only supports causal masking and no packing + attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) + # Create dummy SequenceDescriptor for lazy_init + dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) + max_segments_per_seq = self.config.max_segments_per_seq + elif using_context_parallelism and self.config.dataset_type == "synthetic": + if self.attention_type == AttentionType.LOCAL_SLIDING: + raise AssertionError("Sliding window attention is not supported for context parallelism") + # Context parallelism without packing: only supports causal masking attn_mask = None + dummy_attn_mask = None mask_type = "causal" else: + # Default case: no packing, no context parallelism + dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -76,11 +80,23 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - # scale_factor=1.0, + scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis="context", + # context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) - return dpa_layer(query, key, value, mask=attn_mask) + + dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs) + dummy_query_prefill = jnp.zeros( + (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype + ) + dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype) + dummy_value_prefill = jnp.zeros( + (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype + ) + + dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, sequence_descriptor=dummy_attn_mask) + return dpa_layer(query, key, value, sequence_descriptor=attn_mask) diff --git a/primus/backends/maxtext/layers/attentions.py b/primus/backends/maxtext/layers/attentions.py deleted file mode 100644 index ceb880d84..000000000 --- a/primus/backends/maxtext/layers/attentions.py +++ /dev/null @@ -1,49 +0,0 @@ -############################################################################### -# Copyright 2023–2025 Google LLC. All rights reserved. -# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### - -from typing import Tuple - -from flax import nnx -from MaxText.layers.attentions import Attention -from MaxText.layers.linears import DenseGeneral - - -class PrimusAttention(Attention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: - """Query projection initialization.""" - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - # depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - depth_scaling = 1.0 - - def query_init(*args): - # pylint: disable=no-value-for-parameter - return self.kernel_init(*args) / depth_scaling - - kernel_axes = ( - (None, None, None) - if self.config.ici_context_autoregressive_parallelism > 1 - else ("embed", "q_heads", "kv") - ) - return DenseGeneral( - in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape), - out_features_shape=(self.num_query_heads, self.head_dim), - axis=-1, - kernel_init=query_init, - kernel_axes=kernel_axes, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - matmul_precision=self.config.matmul_precision, - use_bias=self.use_bias_in_projections, - rngs=self.rngs, - ) diff --git a/primus/backends/maxtext/layers/gemma.py b/primus/backends/maxtext/layers/gemma.py new file mode 100644 index 000000000..7170cdf13 --- /dev/null +++ b/primus/backends/maxtext/layers/gemma.py @@ -0,0 +1,119 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from typing import Optional + +from flax import nnx +from jax.sharding import Mesh + +from MaxText import max_utils +from MaxText.common_types import MODEL_MODE_PREFILL, Config +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention +from MaxText.layers.linears import MlpBlock +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant + +from MaxText.layers.gemma import GemmaDecoderLayer + + +class PrimusGemmaDecoderLayer(GemmaDecoderLayer): + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: Optional[Quant] = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=self.mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=self.model_mode, + rngs=self.rngs, + ) + + if config.use_post_attn_norm: + self.post_self_attention_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.pre_ffw_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.mlp_global = MlpBlock( + config=config, + mesh=self.mesh, + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + quant=self.quant, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + if config.use_post_ffw_norm: + self.post_ffw_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + if model_mode == MODEL_MODE_PREFILL: + self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + else: + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file diff --git a/primus/backends/maxtext/layers/gemma2.py b/primus/backends/maxtext/layers/gemma2.py new file mode 100644 index 000000000..d4707059d --- /dev/null +++ b/primus/backends/maxtext/layers/gemma2.py @@ -0,0 +1,195 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from typing import Optional + +from flax import nnx +from jax.sharding import Mesh + +from MaxText import max_utils +from MaxText.common_types import Config, MODEL_MODE_PREFILL +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention, AttentionType +from MaxText.layers.linears import MlpBlock, Dropout +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant + +from MaxText.layers.gemma2 import Gemma2DecoderLayer + + +class PrimusGemma2DecoderLayer(Gemma2DecoderLayer): + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: Optional[Quant] = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_norm_local = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.self_attention_local = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=self.mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=config.sliding_window_size, + attn_logits_soft_cap=config.attn_logits_soft_cap, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=self.model_mode, + rngs=self.rngs, + ) + + if config.use_post_attn_norm: + self.post_self_attention_norm_local = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.pre_ffw_norm_local = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.mlp_local = MlpBlock( + config=config, + mesh=self.mesh, + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + quant=self.quant, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + if config.use_post_ffw_norm: + self.post_ffw_norm_local = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + self.pre_self_attention_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.self_attention_global = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=self.mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=True, + float32_logits=True, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + attention_type=AttentionType.GLOBAL, + attn_logits_soft_cap=config.attn_logits_soft_cap, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=model_mode, + rngs=self.rngs, + ) + + if config.use_post_attn_norm: + self.post_self_attention_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.pre_ffw_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.mlp_global = MlpBlock( + config=config, + mesh=self.mesh, + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + quant=self.quant, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + if config.use_post_ffw_norm: + self.post_ffw_norm_global = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + if model_mode == MODEL_MODE_PREFILL: + self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + else: + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/llama2.py b/primus/backends/maxtext/layers/llama2.py new file mode 100644 index 000000000..7a3b8e6a5 --- /dev/null +++ b/primus/backends/maxtext/layers/llama2.py @@ -0,0 +1,115 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +import functools + +from flax import nnx +from jax.sharding import Mesh + +from MaxText import max_utils +from MaxText.sharding import maybe_shard_with_logical +from MaxText.common_types import MODEL_MODE_PREFILL, Config +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention +from MaxText.layers.linears import MlpBlock, Dropout +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.layers.llama2 import LlamaDecoderLayer + + +class PrimusLlamaDecoderLayer(LlamaDecoderLayer): + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: None | Quant = None, + ): + + self.config = config + self.mesh = mesh + self.quant = quant + + if model_mode == MODEL_MODE_PREFILL: + self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + else: + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + shard_mode=config.shard_mode, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), + reshape_q=config.reshape_q, + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=model_mode, + rngs=rngs, + ) + + self.post_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + shard_mode=config.shard_mode, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.mlp = MlpBlock( + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + config=config, + mesh=mesh, + quant=self.quant, + model_mode=model_mode, + rngs=rngs, + ) + + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + + self._maybe_shard_with_logical = functools.partial( + maybe_shard_with_logical, + mesh=self.mesh, + shard_mode=config.shard_mode, + ) \ No newline at end of file diff --git a/primus/backends/maxtext/layers/mistral.py b/primus/backends/maxtext/layers/mistral.py new file mode 100644 index 000000000..37ae9e68c --- /dev/null +++ b/primus/backends/maxtext/layers/mistral.py @@ -0,0 +1,103 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from flax import nnx +from jax.sharding import Mesh + +from MaxText import max_utils +from MaxText.common_types import Config + +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention +from MaxText.layers.linears import Dropout, MlpBlock +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.layers.mistral import MistralDecoderLayer + + +class PrimusMistralDecoderLayer(MistralDecoderLayer): + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + *, + rngs: nnx.Rngs, + quant: None | Quant = None, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=self.rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), + reshape_q=config.reshape_q, + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=model_mode, + rngs=self.rngs, + ) + + self.post_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=self.rngs, + ) + + self.mlp = MlpBlock( + mesh=self.mesh, + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + config=config, + quant=self.quant, + model_mode=model_mode, + rngs=self.rngs, + ) + + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/mixtral.py b/primus/backends/maxtext/layers/mixtral.py new file mode 100644 index 000000000..e19fce01f --- /dev/null +++ b/primus/backends/maxtext/layers/mixtral.py @@ -0,0 +1,108 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from flax import linen as nn +from flax import nnx +from jax.sharding import Mesh + +from MaxText import max_utils +from MaxText.common_types import Config +from MaxText.layers import initializers +from MaxText.layers import moe +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention +from MaxText.layers.linears import Dropout +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant + +from MaxText.layers.mistral import MixtralDecoderLayer + + +class PrimusMixtralDecoderLayer(MixtralDecoderLayer): + @nn.compact + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=self.rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), + reshape_q=config.reshape_q, + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), + model_mode=model_mode, + rngs=self.rngs, + ) + + self.post_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=self.rngs, + ) + + self.MoeBlock_0 = moe.RoutedMoE( + config=config, + num_experts=config.num_experts, + num_experts_per_tok=config.num_experts_per_tok, + mesh=mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=config.mlp_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + quant=self.quant, + rngs=self.rngs, + ) + + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file diff --git a/primus/backends/maxtext/layers/moe.py b/primus/backends/maxtext/layers/moe.py index 10b0ea36c..b36396725 100644 --- a/primus/backends/maxtext/layers/moe.py +++ b/primus/backends/maxtext/layers/moe.py @@ -59,6 +59,9 @@ def dense_matmul( w0_kernel, w1_kernel, wo_kernel, + w0_bias, + w1_bias, + wo_bias, ) -> tuple[jax.Array, Optional[jax.Array]]: """Dense matrix multiplication.""" if self.config.expert_balance: @@ -82,7 +85,7 @@ def dense_matmul( gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts)) ############################################# end #################################################### ########################################## - return super().dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) + return super().dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias) def sparse_matmul( self, @@ -92,6 +95,9 @@ def sparse_matmul( w0_kernel, w1_kernel, wo_kernel, + w0_bias, + w1_bias, + wo_bias, ): """Perform sparse matrix multiplication with optional Primus Turbo backend.""" if not self.config.use_turbo_grouped_gemm: @@ -109,7 +115,7 @@ def sparse_matmul( # Fallback to original implementation if primus_turbo is not available max_logging.log("WARNING: primus_turbo not available, using default ragged_dot in MoE") return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias ) max_logging.log("Using primus_turbo grouped_gemm in MoE") @@ -129,7 +135,7 @@ def _turbo_ragged_dot(*, lhs, rhs, group_sizes, preferred_element_type=None, **k jax.lax.ragged_dot = _turbo_ragged_dot try: return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias ) finally: jax.lax.ragged_dot = _orig_ragged_dot diff --git a/primus/backends/maxtext/max_utils.py b/primus/backends/maxtext/max_utils.py index 84ed82253..ac1802261 100644 --- a/primus/backends/maxtext/max_utils.py +++ b/primus/backends/maxtext/max_utils.py @@ -10,7 +10,134 @@ import socket import jax +import orbax.checkpoint as ocp +from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import initialization + from MaxText import max_logging +from MaxText.max_utils import _retrieve_jax_init_info, is_gpu_backend, is_cpu_backend, get_coordinator_ip_address + + +def maybe_initialize_jax_distributed_system(raw_keys): + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. + + Currently jax.distributed.initialize() fully works as expected! + + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + """ + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return + if raw_keys["enable_single_controller"]: + max_logging.log("Skipping jax distributed system since its not needed for single controller.") + return + if jax.distributed.is_initialized(): + max_logging.log("Jax distributed system is already initialized.") + return + if raw_keys["inference_benchmark_test"]: + # Disable initialization for inference benmark test. + return + if raw_keys["compile_topology"]: + # Don't initialize jax distributed with AOT compilation + return + if is_gpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") + initialize_jax_for_gpu(raw_keys) + max_logging.log("Jax distributed system initialized on GPU!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu(raw_keys) + max_logging.log("Jax distributed system initialized on CPUs!") + elif raw_keys["enable_multi_tier_checkpointing"]: + max_logging.log("Attempting to initialize the jax distributed system for multi-tier " "checkpointing...") + initialization.initialize_multi_tier_checkpointing( + local_checkpoint_directory=raw_keys["local_checkpoint_directory"], + backup_interval_minutes=raw_keys["multi_tier_checkpointing_backup_interval_minutes"], + run_name=raw_keys["run_name"], + jax_initialization_timeout_seconds=raw_keys["jax_distributed_initialization_timeout"], + data_parallelism=raw_keys["mtc_data_parallelism"], + ) + max_logging.log("Jax distributed system initialized for multi-tier checkpointing!") + elif (raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1) or raw_keys[ + "hardware" + ] == "gpu_multiprocess": + max_logging.log("Attempting to initialize the jax distributed system...") + if not raw_keys["enable_emergency_checkpoint"]: + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + else: + if raw_keys["hardware"] == "gpu_multiprocess": + max_logging.log("Initializing jax distribtued to support local checkpointing with" " GPUs...") + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + ocp.multihost.initialize_runtime_to_distributed_ids() + ocp.multihost.initialize_distributed_to_device_ids() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) + max_logging.log("Jax distributed system initialized!") + + +def initialize_jax_for_gpu(raw_keys): + """Jax distributed initialize for GPUs.""" + if os.environ.get("JAX_COORDINATOR_IP") is not None: + coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + jax.distributed.initialize( + coordinator_address=f"{coordinator_ip}:{coordinator_port}", + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK")), + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + max_logging.log(f"JAX global devices: {jax.devices()}") + + +def initialize_jax_for_cpu(raw_keys): + """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready.""" + coordinator_ip_address = get_coordinator_ip_address() + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("JOB_INDEX")) + job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + + +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + The information required to initialize JAX distributed runtime will be written by GKE to + the local checkpoint directory. This function retrieves that information and initializes + JAX distributed runtime. + """ + process_id, coordinator_address = _retrieve_jax_init_info(raw_keys) + + if process_id != "" and coordinator_address != "": + max_logging.log( + f"Using {process_id} as the process_id and {coordinator_address} as the" + " coordinator_address to initialize JAX distributed runtime..." + ) + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=int(process_id), + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + + ocp.multihost.initialize_runtime_to_distributed_ids() + ocp.multihost.initialize_distributed_to_device_ids() def print_system_information(): diff --git a/primus/backends/maxtext/train.py b/primus/backends/maxtext/train.py index 7aee03126..d5d582a36 100644 --- a/primus/backends/maxtext/train.py +++ b/primus/backends/maxtext/train.py @@ -51,142 +51,136 @@ from MaxText.vertex_tensorboard import VertexTensorboardManager -def validate_train_config(config): - """Validates the configuration is set correctly for 'train.py'.""" - - assert config.run_name, "Erroring out, need a real run_name" - if config.dataset_path and not config.dataset_path.startswith("gs://"): - max_logging.log("WARNING: 'dataset_path' might be pointing your local file system") - if not config.base_output_directory.startswith("gs://"): - max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") - assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer." - - if config.quantization in ("fp8", "nanoo_fp8"): - # pylint: disable=line-too-long - assert config.gradient_accumulation_steps == 1, ( - "fp8 can't be used with gradient_accumulation_steps right now. Please use other quantization or set " - "gradient_accumulation_steps to 1" - ) - - # Check if GPU Flash Attention is being used with sequence packing - # if config.attention == "cudnn_flash_te" and config.packing and config.dataset_type != "synthetic": - # raise ValueError( - # "cudnn_flash_te only supports BSHD format. The THD (seq packing) support is going to be available in " - # "Transformer Engine 2.0 release. " - # "Please disable sequence packing (set packing=False) or use a different attention mechanism. " - # "With synthetic data, the format is not important as packing is not applied." - # ) - - def train_loop(config, recorder, state=None): - """Main Training loop.""" - ( - init_rng, - checkpoint_manager, - state_mesh_shardings, - model, - mesh, - learning_rate_schedule, - data_iterator, - eval_data_iterator, - state, - ) = setup_train_loop(config, recorder) - - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) - - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator - ) + """Main Training loop.""" + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + state, + ) = train_utils.setup_train_loop(config, recorder) + + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + shaped_batch = maxtext_utils.get_shaped_batch(config) + if config.shard_optimizer_over_data: + state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) + if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded + compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled_stats = compiled.memory_analysis() + max_utils.print_compiled_memory_stats(compiled_stats) + + start_step = get_first_step(state) # this is the start_step for training + prof = profiler.Profiler(config, offset_step=start_step) + metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) + + # Write train config params, num model params, and XLA flags to tensorboard + metric_logger.write_setup_info_to_tensorboard(state.params) + + try: + last_step_completion = datetime.datetime.now() + for step in np.arange(start_step, config.steps): + prof.maybe_activate_profiler(step, state) + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) + # Reshard data from loaded sharding to performant activation sharding + example_batch = sharding.maybe_shard_with_name( + example_batch, + sharding.get_input_data_sharding(config, mesh), + shard_mode=config.shard_mode, + ) + # pylint: disable=not-callable + nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + with maybe_record_goodput(recorder, GoodputEvent.STEP, step): + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + if config.shard_optimizer_over_data: + state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) + state, metrics = p_train_step(state, example_batch, nextrng) + jax.block_until_ready(state) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - shaped_batch = maxtext_utils.get_shaped_batch(config) - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() - compiled_stats = compiled.memory_analysis() - max_utils.print_compiled_memory_stats(compiled_stats) - - start_step = get_first_step(state) # this is the start_step for training - prof = profiler.Profiler(config, offset_step=start_step) - data_loader = DataLoader(config, mesh, data_iterator, recorder) - metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) - - # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) - - try: - last_step_completion = datetime.datetime.now() - for step in np.arange(start_step, config.steps): - prof.maybe_activate_profiler(step, state) - - with jax.profiler.StepTraceAnnotation("train", step_num=step): - example_batch = data_loader.load_next_batch() - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) - with maybe_record_goodput(recorder, GoodputEvent.STEP, step): - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - state, metrics = p_train_step(state, example_batch, nextrng) - jax.block_until_ready(state) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint( - checkpoint_manager, state_to_save, config, data_iterator, step - ) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. - gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - - # Explicitly reset the eval counters before starting the eval loop - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if ( - metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] - <= config.target_eval_loss - ): - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) - except exceptions.StopTraining as e: - max_logging.log(f"Training stopped: {str(e)}") - finally: - metric_logger.flush_metrics_and_cleanup() - - return state + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, nextrng) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + + if config.save_checkpoint_on_completion: + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + if checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() + except exceptions.StopTraining as e: + max_logging.log(f"Training stopped: {str(e)}") + finally: + metric_logger.flush_metrics_and_cleanup() + + return state def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, Any, Any]: @@ -231,6 +225,10 @@ def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, def run(config, recorder, diagnostic_config): """Run the job given hyperparameters and utilities""" - with diagnostic.diagnose(diagnostic_config): - with maybe_record_goodput(recorder, GoodputEvent.JOB): - train_loop(config, recorder) + with ( + diagnostic.diagnose(diagnostic_config), + maybe_record_goodput(recorder, GoodputEvent.JOB), + max_utils.maybe_get_transformer_engine_context(config), + maybe_monitor_goodput(config), + ): + train_loop(config, recorder) \ No newline at end of file diff --git a/primus/backends/maxtext/train_utils.py b/primus/backends/maxtext/train_utils.py index 42641caaa..bfd70b1f2 100644 --- a/primus/backends/maxtext/train_utils.py +++ b/primus/backends/maxtext/train_utils.py @@ -15,26 +15,25 @@ def create_training_tools(config, model, mesh): learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) logger = checkpointing.setup_checkpoint_logger(config) - if config.enable_emergency_checkpoint: - if config.use_replicator_service: - checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( - config.local_checkpoint_directory, - config.local_checkpoint_period, - mesh, - ) - else: - abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, tx, config, init_rng, mesh, is_training=True - ) - checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( - config.local_checkpoint_directory, - config.checkpoint_dir, - mesh, - abstract_state, - config.local_checkpoint_period, - config.checkpoint_period, - logger, - ) + if config.enable_multi_tier_checkpointing: + checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( + config.local_checkpoint_directory, + config.local_checkpoint_period, + mesh, + ) + elif config.enable_emergency_checkpoint: + abstract_state, _, _ = maxtext_utils.get_abstract_state( + model, tx, config, init_rng, mesh, is_training=True + ) + checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( + config.local_checkpoint_directory, + config.checkpoint_dir, + mesh, + abstract_state, + config.local_checkpoint_period, + config.checkpoint_period, + logger, + ) else: # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend use_ocdbt = config.checkpoint_storage_use_ocdbt @@ -54,7 +53,7 @@ def create_training_tools(config, model, mesh): logger, use_ocdbt, use_zarr3, - config.max_to_keep, + config.max_num_checkpoints_to_keep, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx diff --git a/primus/configs/modules/maxtext/trainer_base.yaml b/primus/configs/modules/maxtext/trainer_base.yaml index f7700b1cb..fae170c65 100644 --- a/primus/configs/modules/maxtext/trainer_base.yaml +++ b/primus/configs/modules/maxtext/trainer_base.yaml @@ -48,7 +48,6 @@ async_checkpointing: true checkpoint_period: 10_000 # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: false -max_to_keep: 5 force_unroll: false # during generate_param_only_checkpoint should we unroll the loop? @@ -121,9 +120,11 @@ save_quantized_params_path: "" model_call_mode: "" use_qwix_quantization: false # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix. # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 -quantization_calibration_method: "absmax" # Shard the range finding operation for quantization. By default this is set to number of slices. quantization_local_shard_count: -1 +act_quantization_calibration_method: absmax +bwd_quantization_calibration_method: absmax +weight_quantization_calibration_method: absmax # Global parameter scale needs to be a power of 2. If you want finer grained control of the model sizes # then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads, @@ -154,10 +155,6 @@ megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default use_turbo_grouped_gemm: false # requires sparse_matmul=true and JAX_ENABLE_X64=1 -# Tunable tiling dimensions used for Megablox -tile_batch_seq: 512 -tile_activation_dim: 1024 -tile_weight_dim: 1024 # How the expert axis is used to shard attention weights and activations # "fsdp" (ep acts as fsdp parallelism) @@ -246,6 +243,8 @@ param_scan_axis: 1 # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. global or local_sliding # moved to model_base.yaml +attention_bias: False +attention_sink: False # MLA parameters # moved to model_base.yaml @@ -271,12 +270,6 @@ local_checkpoint_directory: "" # It should be a positive number when and only when `enable_emergency_checkpoint` is True. local_checkpoint_period: 0 -# Whether to use emergency checkpoint with the replicator service. -use_replicator_service: false - -# The interval to backup local checkpoints to the persistent storage. -replicator_backup_interval_minutes: 0 - # Jax cache directory jax_cache_dir: "~/jax_cache" @@ -412,7 +405,6 @@ train_image_column: 'image' eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" eval_image_column: 'image' packing: true -max_segments: 32 num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1 # direct preference optimization (DPO) @@ -449,6 +441,7 @@ hf_access_token: '' # For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline grain_train_files: '' grain_eval_files: '' +grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data. grain_file_type: 'arrayrecord' # arrayrecord or parquet grain_worker_count: 1 grain_worker_count_eval: 1 @@ -733,3 +726,90 @@ projector_dropout_for_vit: 0.0 # Subslice shape in the form of "x,y,z" when using pathways (single controller). # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium. subslice_shape: "" + +# tile +wi_tile_dlhs_batch_seq: 512 +wi_tile_dlhs_embed_dim: 1024 +wi_tile_dlhs_mlp_dim: 1024 +wi_tile_drhs_batch_seq: 512 +wi_tile_drhs_embed_dim: 1024 +wi_tile_drhs_mlp_dim: 1024 +wi_tile_fwd_batch_seq: 512 +wi_tile_fwd_embed_dim: 1024 +wi_tile_fwd_mlp_dim: 1024 +wo_tile_dlhs_batch_seq: 512 +wo_tile_dlhs_embed_dim: 1024 +wo_tile_dlhs_mlp_dim: 1024 +wo_tile_drhs_batch_seq: 512 +wo_tile_drhs_embed_dim: 1024 +wo_tile_drhs_mlp_dim: 1024 +wo_tile_fwd_batch_seq: 512 +wo_tile_fwd_embed_dim: 1024 +wo_tile_fwd_mlp_dim: 1024 + +chat_template_path: "" +context_parallel_strategy: all_gather +conv_stride_for_vit: 14 +cost_estimate_flops_bwd: -1 +cost_estimate_flops_fwd: -1 +debug.rl: False +deepstack_visual_indexes_for_vit: [] +dq_reduction_steps: 0 +enable_multi_tier_checkpointing: False +enable_nnx: False +enable_rampup_batch_size: False +float32_weight_sum: True +fsdp_shard_on_exp: False +gdn_chunk_size: 64 +gdn_conv_kernel_dim: 4 +gdn_key_head_dim: 128 +gdn_num_key_heads: 16 +gdn_num_value_heads: 32 +gdn_value_head_dim: 128 +generate_padding_batch_eval: False +generate_padding_batch_train: False +global_rampup_samples: 500 +grad_dtype: float32 +grain_data_source_max_workers: 16 +grain_num_threads: 16 +grain_num_threads_eval: 16 +grain_per_worker_buffer_size: 1 +grain_per_worker_buffer_size_eval: 1 +grain_prefetch_buffer_size: 500 +grain_prefetch_buffer_size_eval: 500 +hide_profiler_step_metric: False +max_num_checkpoints_to_keep: None +max_num_images_per_example: -1 +max_segments_per_seq: 32 +mlp_activations_limit: -1.0 +mlp_bias: False +moba: False +moba_chunk_size: 1024 +moba_topk: 8 +moe_fsdp_use_two_stage_all_gather: False +mtc_data_parallelism: 0 +multi_tier_checkpointing_backup_interval_minutes: 0 +num_position_embeddings_for_vit: 1024 +num_vocab_tiling: 1 +out_hidden_size_for_vit: 512 +partial_rotary_factor: 1.0 +per_device_batch_size_increment: 2.0 +per_device_batch_size_start: 4.0 +posemb_type_for_vit: learn +rope_attention_scaling: False +rope_interleave: True +rope_linear_scaling_factor: 1.0 +rope_truncate: True +save_checkpoint_on_completion: True +shard_mode: auto +shard_optimizer_over_data: False +spatial_merge_size_for_vit: 2 +temporal_patch_size_for_vit: 2 +use_batch_split_schedule: False +use_custom_sort_vjp: True +use_max_logit_estimate: -1 +use_qk_norm_in_gdn: True +use_ring_of_experts: False +use_tokamax_gmm: False +use_tokamax_splash: False +use_truncation: True From a506657990f171122a2fea26d8342242d30b30ef Mon Sep 17 00:00:00 2001 From: liyingli Date: Mon, 9 Feb 2026 07:09:35 +0000 Subject: [PATCH 03/24] update maxtext: part II --- primus/backends/maxtext/configs/types.py | 246 +++++++++++++++++- .../backends/maxtext/layers/attention_op.py | 2 +- primus/backends/maxtext/layers/gemma.py | 31 +-- primus/backends/maxtext/max_utils.py | 33 +-- primus/backends/maxtext/train.py | 107 ++++---- .../configs/modules/maxtext/trainer_base.yaml | 13 +- primus/modules/trainer/maxtext/pre_trainer.py | 47 +++- primus/pretrain.py | 2 + 8 files changed, 367 insertions(+), 114 deletions(-) diff --git a/primus/backends/maxtext/configs/types.py b/primus/backends/maxtext/configs/types.py index c79555ed7..99f905912 100644 --- a/primus/backends/maxtext/configs/types.py +++ b/primus/backends/maxtext/configs/types.py @@ -1,6 +1,100 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +import os + +from pydantic import BaseModel from pydantic.fields import Field +from pydantic import model_validator +from pydantic import ConfigDict +from typing import Any + +from MaxText.configs.types import ( + MoEGeneral, + DevelopmentAndDebugging, + MaxTextConfig, + # Run and Checkpointing + RunInfo, + Checkpointing, + OrbaxStorage, + EmergencyCheckpointing, + # Data Types and Quantization + DataTypes, + Quantization, + # Core Model Architecture + ModelArchitecture, + MTP, + Logits, + # Attention Mechanisms + Attention, + MlaAttention, + MoBa, + Llama4Attention, + SplashAttention, + PagedAttention, + # Mixture of Experts + MoEKernels, + DeepSeekMoE, + Qwen3Next, + # Parallelism and Layout + HardwareAndMesh, + LayoutAndSharding, + DcnParallelism, + IciParallelism, + PipelineParallelism, + # Training, Optimization, and Fine-Tuning + RematAndOffload, + TrainingLoop, + Optimizer, + AdamW, + FineTuning, + # Reinforcement Learning + RLHardware, + VLLM, + GRPO, + RLDataset, + RLEvaluation, + Reward, + SpecialTokens, + # Positional Embeddings + PositionalEmbedding, + Rope, + YarnRope, + # Dataset Loading and Tokenization + DatasetGeneral, + TfdsDataset, + HfDataset, + GrainDataset, + Tokenizer, + # Inference + InferenceGeneral, + Decoding, + InferenceLayout, + InferenceServer, + InferenceBenchmark, + PrefixCaching, + # Development and Debugging + AOT, + Profiling, + HloDump, + StackTrace, + # Metrics and Monitoring + Metrics, + Goodput, + GcpMonitoring, + Tensorboard, + # Multimodal + MultimodalGeneral, + VisionTower, + VisionProjector, + # Derived + DerivedValues, + Debug, +) -from MaxText.configs.types import MoEGeneral, DevelopmentAndDebugging class PrimusMoEGeneral(MoEGeneral): expert_balance: bool = Field(False, description="Whether to use expert balancing.") @@ -9,4 +103,152 @@ class PrimusMoEGeneral(MoEGeneral): class PrimusDevelopmentAndDebugging(DevelopmentAndDebugging): jax_distributed_heartbeat_timeout_seconds: int = Field( 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." - ) \ No newline at end of file + ) + + +class PrimusTurboConfig(BaseModel): + enable_primus_turbo: bool = Field(False, description="Whether to enable Primus Turbo.") + use_turbo_grouped_gemm: bool = Field(False, description="Whether to use turbo grouped gemm.") + + +class PrimusWandbConfig(BaseModel): + enable_wandb: bool = Field(False, description="Whether to enable WandB.") + wandb_project: None | str = Field(None, description="The name of the WandB project.") + wandb_exp_name: None | str = Field(None, description="The name of the WandB experiment, derived from the run_name if not set.") + wandb_save_dir: None | str = Field(None, description="The directory to save the WandB logs.") + + +class PrimusMaxTextConfig( + # Run and Checkpointing + RunInfo, + Checkpointing, + OrbaxStorage, + EmergencyCheckpointing, + # Data Types and Quantization + DataTypes, + Quantization, + # Core Model Architecture + ModelArchitecture, + MTP, + Logits, + # Attention Mechanisms + Attention, + MlaAttention, + MoBa, + Llama4Attention, + SplashAttention, + PagedAttention, + # Mixture of Experts - REPLACED with PrimusMoEGeneral + PrimusMoEGeneral, # Replaces MoEGeneral + MoEKernels, + DeepSeekMoE, + Qwen3Next, + # Parallelism and Layout + HardwareAndMesh, + LayoutAndSharding, + DcnParallelism, + IciParallelism, + PipelineParallelism, + # Training, Optimization, and Fine-Tuning + RematAndOffload, + TrainingLoop, + Optimizer, + AdamW, + FineTuning, + # Reinforcement Learning + RLHardware, + VLLM, + GRPO, + RLDataset, + RLEvaluation, + Reward, + SpecialTokens, + # Positional Embeddings + PositionalEmbedding, + Rope, + YarnRope, + # Dataset Loading and Tokenization + DatasetGeneral, + TfdsDataset, + HfDataset, + GrainDataset, + Tokenizer, + # Inference + InferenceGeneral, + Decoding, + InferenceLayout, + InferenceServer, + InferenceBenchmark, + PrefixCaching, + # Development and Debugging - REPLACED with PrimusDevelopmentAndDebugging + AOT, + PrimusDevelopmentAndDebugging, # Replaces DevelopmentAndDebugging + Profiling, + HloDump, + StackTrace, + # Metrics and Monitoring + Metrics, + Goodput, + GcpMonitoring, + Tensorboard, + # Multimodal + MultimodalGeneral, + VisionTower, + VisionProjector, + # Primus-specific configs - ADDED + PrimusTurboConfig, + PrimusWandbConfig, + # Derived + DerivedValues, +): + """ + The main configuration object for Primus MaxText. + + This class extends MaxTextConfig with Primus-specific configurations: + - Replaces MoEGeneral with PrimusMoEGeneral (adds expert_balance) + - Replaces DevelopmentAndDebugging with PrimusDevelopmentAndDebugging (adds jax_distributed_heartbeat_timeout_seconds) + - Adds PrimusTurboConfig (Primus Turbo optimizations) + - Adds PrimusWandbConfig (WandB integration) + + All other functionality from MaxTextConfig is preserved. + """ + + debug: Debug = Field(default_factory=Debug) + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + + @model_validator(mode="before") + @classmethod + def load_model_specific_defaults(cls, values: dict[str, Any]) -> dict[str, Any]: + """This method is a no-op because `pyconfig` handles model-specific config loading.""" + return values + + @model_validator(mode="after") + def set_derived_and_validate_values(self) -> "PrimusMaxTextConfig": + """ + Computes all derived values and runs all cross-field validations after initial parsing. + This calls the MaxTextConfig's validation logic and then adds any Primus-specific validations. + """ + # Call MaxTextConfig's validation logic directly since we're using composition via multiple inheritance + # rather than direct inheritance. MaxTextConfig.set_derived_and_validate_values expects a MaxTextConfig + # instance, but since we have all the same base classes, we can call it on self. + # We need to temporarily cast self to MaxTextConfig for the method call, or call the method directly. + # Actually, since MaxTextConfig's method works on the same fields we have, we can call it directly. + MaxTextConfig.set_derived_and_validate_values(self) + + # Add any Primus-specific validations here if needed + if self.wandb_save_dir is None or self.wandb_save_dir == "" and self.base_output_directory: + self.wandb_save_dir = os.path.join(self.base_output_directory, "wandb") + + if self.wandb_project is None or self.wandb_project == "": + self.wandb_project = os.getenv("WANDB_PROJECT", "Primus-MaxText-Pretrain") + + if self.wandb_exp_name is None or self.wandb_exp_name == "" and self.run_name: + self.wandb_exp_name = self.run_name + + if self.enable_wandb and "WANDB_API_KEY" not in os.environ: + raise ValueError("WANDB_API_KEY is not set. Please set it or login wandb before proceeding") + + if not self.enable_primus_turbo: + self.use_turbo_grouped_gemm = False + + return self diff --git a/primus/backends/maxtext/layers/attention_op.py b/primus/backends/maxtext/layers/attention_op.py index 7e85c089c..d104caf5c 100644 --- a/primus/backends/maxtext/layers/attention_op.py +++ b/primus/backends/maxtext/layers/attention_op.py @@ -56,7 +56,7 @@ def cudnn_flash_attention( dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) max_segments_per_seq = self.config.max_segments_per_seq - elif using_context_parallelism and self.config.dataset_type == "synthetic": + elif using_context_parallelism or self.config.dataset_type == "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: raise AssertionError("Sliding window attention is not supported for context parallelism") # Context parallelism without packing: only supports causal masking diff --git a/primus/backends/maxtext/layers/gemma.py b/primus/backends/maxtext/layers/gemma.py index 7170cdf13..f7a1085cc 100644 --- a/primus/backends/maxtext/layers/gemma.py +++ b/primus/backends/maxtext/layers/gemma.py @@ -11,10 +11,10 @@ from jax.sharding import Mesh from MaxText import max_utils -from MaxText.common_types import MODEL_MODE_PREFILL, Config +from MaxText.common_types import Config from MaxText.layers import quantizations from MaxText.layers.attentions import Attention -from MaxText.layers.linears import MlpBlock +from MaxText.layers.linears import MlpBlock, Dropout from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant @@ -73,16 +73,7 @@ def __init__( rngs=self.rngs, ) - if config.use_post_attn_norm: - self.post_self_attention_norm_global = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - rngs=self.rngs, - ) - - self.pre_ffw_norm_global = RMSNorm( + self.pre_ffw_norm = RMSNorm( num_features=config.emb_dim, dtype=config.dtype, weight_dtype=config.weight_dtype, @@ -90,7 +81,7 @@ def __init__( rngs=self.rngs, ) - self.mlp_global = MlpBlock( + self.mlp = MlpBlock( config=config, mesh=self.mesh, in_features=config.emb_dim, @@ -104,16 +95,6 @@ def __init__( rngs=self.rngs, ) - if config.use_post_ffw_norm: - self.post_ffw_norm_global = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - rngs=self.rngs, - ) + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) - if model_mode == MODEL_MODE_PREFILL: - self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") - else: - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file diff --git a/primus/backends/maxtext/max_utils.py b/primus/backends/maxtext/max_utils.py index ac1802261..a5c0b3e8d 100644 --- a/primus/backends/maxtext/max_utils.py +++ b/primus/backends/maxtext/max_utils.py @@ -182,38 +182,11 @@ def initialize_wandb_writer(config): if jax.process_index() != 0 or not config.enable_wandb: return None - def safe_get_config(config, key, default=None): - try: - return getattr(config, key) - except KeyError: - return default - import wandb + os.makedirs(config.wandb_save_dir, exist_ok=True) - if safe_get_config(config, "wandb_save_dir") is None or config.wandb_save_dir == "": - wandb_save_dir = os.path.join(config.base_output_directory, "wandb") - else: - wandb_save_dir = config.wandb_save_dir - - if safe_get_config(config, "wandb_project") is None or config.wandb_project == "": - wandb_project = os.getenv("WANDB_PROJECT", "Primus-MaxText-Pretrain") - else: - wandb_project = config.wandb_project - if safe_get_config(config, "wandb_exp_name") is None or config.wandb_exp_name == "": - wandb_exp_name = config.run_name - else: - wandb_exp_name = config.wandb_exp_name - - if config.enable_wandb and "WANDB_API_KEY" not in os.environ: - max_logging.log( - "The environment variable WANDB_API_KEY is not set. Please set it or login wandb before proceeding" - ) - return None - - os.makedirs(wandb_save_dir, exist_ok=True) - - wandb.init(project=wandb_project, name=wandb_exp_name, dir=wandb_save_dir, config=dict(config.get_keys())) - max_logging.log(f"WandB logging enabled: {wandb_save_dir=}, {wandb_project=}, {wandb_exp_name=}") + wandb.init(project=config.wandb_project, name=config.wandb_exp_name, dir=config.wandb_save_dir, config=dict(config.get_keys())) + max_logging.log(f"WandB logging enabled: {config.wandb_save_dir=}, {config.wandb_project=}, {config.wandb_exp_name=}") return wandb diff --git a/primus/backends/maxtext/train.py b/primus/backends/maxtext/train.py index d5d582a36..c134afc9c 100644 --- a/primus/backends/maxtext/train.py +++ b/primus/backends/maxtext/train.py @@ -29,18 +29,19 @@ maxtext_utils, profiler, pyconfig, + sharding, train_utils, ) -from MaxText.data_loader import DataLoader +from MaxText.common_types import ShardMode from MaxText.metric_logger import MetricLogger from MaxText.train import ( _merge_dpo_state, _split_dpo_state, eval_step, get_first_step, - setup_train_loop, train_step, ) +from MaxText.train_utils import validate_train_config from MaxText.utils import gcs_utils from MaxText.utils.goodput_utils import ( GoodputEvent, @@ -103,6 +104,16 @@ def train_loop(config, recorder, state=None): # Write train config params, num model params, and XLA flags to tensorboard metric_logger.write_setup_info_to_tensorboard(state.params) + # Synchronize all hosts before entering the training loop. + # Without this barrier, timing variance during initialization (JIT compilation, + # profiler/logger setup, etc.) causes hosts to enter the training loop at different + # times. The first collective operation (data sharding in load_next_batch) then + # times out waiting for straggler hosts, resulting in "collective operation timeout" + # or "stop sending heartbeats" errors. + max_logging.log("====== BARRIER: Synchronizing hosts before training loop ======") + jax.experimental.multihost_utils.sync_global_devices("sync_before_training_loop") + max_logging.log("====== BARRIER PASSED: Starting training loop ======") + try: last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): @@ -183,52 +194,54 @@ def train_loop(config, recorder, state=None): return state -def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, Any, Any]: - """Initialization of hyperparameters and utilities""" - pathwaysutils.initialize() - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - # TF allocates extraneous GPU memory when using TFDS data - # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF - tf.config.set_visible_devices([], "GPU") - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" - ) - # TODO: mazumdera@ : ensure missing mandatory fields in base.yml are filled in in argv, - # or fill in here - config = pyconfig.initialize(argv, **kwargs) - jax.config.update("jax_use_shardy_partitioner", config.shardy) - max_utils.print_system_information() - validate_train_config(config) - max_utils.save_device_information(config) - os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" - vertex_tensorboard_manager = VertexTensorboardManager() - if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): - vertex_tensorboard_manager.configure_vertex_tensorboard(config) - - # Goodput configurations - maybe_monitor_goodput(config) - recorder = create_goodput_recorder(config) - - # Stack traces configurations - debug_config = debug_configuration.DebugConfig( - stack_trace_config=stack_trace_configuration.StackTraceConfig( - collect_stack_trace=config.collect_stack_trace, - stack_trace_to_cloud=config.stack_trace_to_cloud, - stack_trace_interval_seconds=config.stack_trace_interval_seconds, - ) +def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]: + """Initialization of hyperparameters and utilities""" + pathwaysutils.initialize() + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + # TF allocates extraneous GPU memory when using TFDS data + # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF + tf.config.set_visible_devices([], "GPU") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + # TODO: mazumdera@ : ensure missing mandatory fields in base.yml are filled in in argv, + # or fill in here + config = pyconfig.initialize(argv) + max_utils.print_system_information() + validate_train_config(config) + max_utils.save_device_information(config) + jax.config.update("jax_use_shardy_partitioner", config.shardy) + # update explicit sharding-supported config + if config.shard_mode == ShardMode.EXPLICIT: + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" + vertex_tensorboard_manager = VertexTensorboardManager() + if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + vertex_tensorboard_manager.configure_vertex_tensorboard(config) + + # Create the Goodput recorder + recorder = create_goodput_recorder(config) + + # Stack traces configurations + debug_config = debug_configuration.DebugConfig( + stack_trace_config=stack_trace_configuration.StackTraceConfig( + collect_stack_trace=config.collect_stack_trace, + stack_trace_to_cloud=config.stack_trace_to_cloud, + stack_trace_interval_seconds=config.stack_trace_interval_seconds, + ) ) - diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) - return config, recorder, diagnostic_config + diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) + return config, recorder, diagnostic_config def run(config, recorder, diagnostic_config): - """Run the job given hyperparameters and utilities""" - with ( - diagnostic.diagnose(diagnostic_config), - maybe_record_goodput(recorder, GoodputEvent.JOB), - max_utils.maybe_get_transformer_engine_context(config), - maybe_monitor_goodput(config), - ): - train_loop(config, recorder) \ No newline at end of file + """Run the job given hyperparameters and utilities""" + with ( + diagnostic.diagnose(diagnostic_config), + maybe_record_goodput(recorder, GoodputEvent.JOB), + max_utils.maybe_get_transformer_engine_context(config), + maybe_monitor_goodput(config), + ): + train_loop(config, recorder) diff --git a/primus/configs/modules/maxtext/trainer_base.yaml b/primus/configs/modules/maxtext/trainer_base.yaml index fae170c65..fdc93daf3 100644 --- a/primus/configs/modules/maxtext/trainer_base.yaml +++ b/primus/configs/modules/maxtext/trainer_base.yaml @@ -282,6 +282,7 @@ logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_length', ['sequence', 'context', 'expert']], @@ -301,7 +302,7 @@ logical_axis_rules: [ ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['activation_vocab', ['sequence','context']], @@ -310,6 +311,7 @@ logical_axis_rules: [ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', ['sequence']], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -322,15 +324,17 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_tensor_transpose', ['tensor_transpose']], + ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'expert']], + ["q_lora_up_proj",[]], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']], + ["kv_lora_up_proj",[]], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['kv', []], @@ -347,6 +351,8 @@ logical_axis_rules: [ ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], + ['dense_layers', []], + ['moe_layers', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] @@ -576,9 +582,6 @@ report_performance_metric_for_gcp_monitoring: false enable_tensorboard: true enable_wandb: false -wandb_project: "" -wandb_exp_name: "" -wandb_save_dir: "" # Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md # Set to True for GCE, False if running via XPK diff --git a/primus/modules/trainer/maxtext/pre_trainer.py b/primus/modules/trainer/maxtext/pre_trainer.py index 48443a2cb..4b6417352 100644 --- a/primus/modules/trainer/maxtext/pre_trainer.py +++ b/primus/modules/trainer/maxtext/pre_trainer.py @@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs): self.patch_max_utils() self.patch_checkpoint() self.patch_input_pipeline() + self.patch_config_types() self.patch_layers() self.primus_cfg = kwargs.pop("primus_config", None) @@ -121,22 +122,35 @@ def patch_max_utils(self): import MaxText.max_utils as orig_max_utils from primus.backends.maxtext.max_utils import ( + maybe_initialize_jax_distributed_system, + initialize_jax_for_gpu, + initialize_jax_for_cpu, + initialize_jax_for_tpu_with_emergency_checkpointing, print_system_information, save_device_information, ) + orig_max_utils.maybe_initialize_jax_distributed_system = maybe_initialize_jax_distributed_system + orig_max_utils.initialize_jax_for_gpu = initialize_jax_for_gpu + orig_max_utils.initialize_jax_for_cpu = initialize_jax_for_cpu + orig_max_utils.initialize_jax_for_tpu_with_emergency_checkpointing = initialize_jax_for_tpu_with_emergency_checkpointing orig_max_utils.print_system_information = print_system_information orig_max_utils.save_device_information = save_device_information warning_rank_0("MaxText Pre-Trainer: patch max_utils successfully.") def patch_checkpoint(self): import MaxText.checkpointing as orig_checkpointing + import MaxText.train_utils as orig_train_utils from primus.backends.maxtext.checkpointing import ( + load_state_if_possible, create_orbax_checkpoint_manager, ) + from primus.backends.maxtext.train_utils import create_training_tools + orig_checkpointing.load_state_if_possible = load_state_if_possible orig_checkpointing.create_orbax_checkpoint_manager = create_orbax_checkpoint_manager + orig_train_utils.create_training_tools = create_training_tools warning_rank_0("MaxText Pre-Trainer: patch checkpointing successfully.") def patch_wandb(self): @@ -177,6 +191,13 @@ def patch_input_pipeline(self): warning_rank_0("MaxText Pre-Trainer: patch _hf_data_processing successfully.") + def patch_config_types(self): + import MaxText.configs.types as orig_config_types + from primus.backends.maxtext.configs.types import PrimusMaxTextConfig + + orig_config_types.MaxTextConfig = PrimusMaxTextConfig + warning_rank_0("MaxText Pre-Trainer: patch config types successfully.") + def patch_layers(self): def patch_quantization(): import MaxText.layers.quantizations as orig_quantizations @@ -191,18 +212,14 @@ def patch_quantization(): patch_quantization() def patch_attn(): - import MaxText.layers.attention_mla as orig_attention_mla import MaxText.layers.attention_op as orig_attention_op import MaxText.layers.attentions as orig_attentions from primus.backends.maxtext.layers.attention_op import PrimusAttentionOp - from primus.backends.maxtext.layers.attentions import PrimusAttention orig_attention_op.AttentionOp = PrimusAttentionOp orig_attentions.AttentionOp = PrimusAttentionOp - orig_attentions.Attention = PrimusAttention - orig_attention_mla.Attention = PrimusAttention warning_rank_0("MaxText Pre-Trainer: patch Attention successfully.") patch_attn() @@ -216,3 +233,25 @@ def patch_moe(): warning_rank_0("MaxText Pre-Trainer: patch RoutedMoE successfully.") patch_moe() + + def patch_decoder_layer(): + import MaxText.layers.gemma as orig_gemma + import MaxText.layers.gemma2 as orig_gemma2 + import MaxText.layers.llama2 as orig_llama2 + import MaxText.layers.mistral as orig_mistral + import MaxText.layers.mixtral as orig_mixtral + + from primus.backends.maxtext.layers.gemma import PrimusGemmaDecoderLayer + from primus.backends.maxtext.layers.gemma2 import PrimusGemma2DecoderLayer + from primus.backends.maxtext.layers.llama2 import PrimusLlamaDecoderLayer + from primus.backends.maxtext.layers.mistral import PrimusMistralDecoderLayer + from primus.backends.maxtext.layers.mixtral import PrimusMixtralDecoderLayer + + orig_gemma.GemmaDecoderLayer = PrimusGemmaDecoderLayer + orig_gemma2.Gemma2DecoderLayer = PrimusGemma2DecoderLayer + orig_llama2.LlamaDecoderLayer = PrimusLlamaDecoderLayer + orig_mistral.MistralDecoderLayer = PrimusMistralDecoderLayer + orig_mixtral.MixtralDecoderLayer = PrimusMixtralDecoderLayer + warning_rank_0("MaxText Pre-Trainer: patch decoder layer successfully.") + + patch_decoder_layer() \ No newline at end of file diff --git a/primus/pretrain.py b/primus/pretrain.py index f1136a458..62ba2b0e8 100644 --- a/primus/pretrain.py +++ b/primus/pretrain.py @@ -112,6 +112,8 @@ def setup_backend_path(framework: str, backend_path=None, verbose: bool = True): } mapped_name = fallback_name_map.get(framework, framework) default_path = Path(__file__).resolve().parent.parent / "third_party" / mapped_name + if framework == "maxtext" and os.path.join(default_path, "src").exists(): + default_path = os.path.join(default_path, "src") candidate_paths.append(default_path) # Normalize & deduplicate From a6f37d8c8142ed1b3943163304019eb0ab0cf259 Mon Sep 17 00:00:00 2001 From: liyingli Date: Mon, 9 Feb 2026 10:25:48 +0000 Subject: [PATCH 04/24] update maxtext III --- examples/run_local_pretrain.sh | 2 +- examples/run_pretrain.sh | 11 +- primus/backends/maxtext/checkpointing.py | 185 +++++---- primus/backends/maxtext/configs/types.py | 139 +++---- .../input_pipeline/_hf_data_processing.py | 21 +- .../backends/maxtext/layers/attention_op.py | 35 +- primus/backends/maxtext/layers/gemma.py | 8 +- primus/backends/maxtext/layers/gemma2.py | 14 +- primus/backends/maxtext/layers/llama2.py | 15 +- primus/backends/maxtext/layers/mistral.py | 4 +- primus/backends/maxtext/layers/mixtral.py | 10 +- primus/backends/maxtext/layers/moe.py | 24 +- primus/backends/maxtext/max_utils.py | 32 +- primus/backends/maxtext/metric_logger.py | 8 +- primus/backends/maxtext/train.py | 388 +++++++++--------- .../configs/modules/maxtext/trainer_base.yaml | 3 +- primus/modules/trainer/maxtext/pre_trainer.py | 15 +- primus/pretrain.py | 7 +- 18 files changed, 498 insertions(+), 423 deletions(-) diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 4cb36fdb3..f891922a9 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -42,7 +42,7 @@ EXP=${EXP:-"examples/megatron/exp_pretrain.yaml"} # Default docker image if [ "${BACKEND:-}" = "MaxText" ]; then - DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v25.9"} + DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/mad-private:jax_rocm7.1_jax_0.8.2_ci_e5be0ef_20260131_v3"} else DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} fi diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index dd23ec02d..442693e9f 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -102,11 +102,11 @@ export DATA_PATH=${DATA_PATH:-"${PRIMUS_PATH}/data"} export HF_HOME=${HF_HOME:-"${DATA_PATH}/huggingface"} LOG_INFO_RANK0 "Pip installing required packages ..." -if [ "${BACKEND:-}" != "MaxText" ]; then - pip install -r "$PRIMUS_PATH/requirements.txt" --quiet -else - pip install -r "$PRIMUS_PATH/requirements-jax.txt" --quiet -fi +# if [ "${BACKEND:-}" != "MaxText" ]; then +# pip install -r "$PRIMUS_PATH/requirements.txt" --quiet +# else +# pip install -r "$PRIMUS_PATH/requirements-jax.txt" --quiet +# fi FAULT_TOLERANCE_VALUE=$(EXAMPLE_FAULT_TOLERANCE "$@") @@ -274,7 +274,6 @@ if [ "${BACKEND:-}" == "MaxText" ]; then export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 export NVTE_USE_HIPBLASLT=1 - export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" if [ "${DUMP_HLO}" = "1" ]; then mkdir -p "${DUMP_HLO_DIR}" export XLA_FLAGS="$XLA_FLAGS --xla_dump_to=$DUMP_HLO_DIR" diff --git a/primus/backends/maxtext/checkpointing.py b/primus/backends/maxtext/checkpointing.py index 1034f32e5..8c493f18a 100644 --- a/primus/backends/maxtext/checkpointing.py +++ b/primus/backends/maxtext/checkpointing.py @@ -8,17 +8,20 @@ from typing import Any import jax - -from etils import epath -from flax.training import train_state import orbax.checkpoint as ocp import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager - +from etils import epath +from flax.training import train_state from MaxText import max_logging -from MaxText.checkpointing import _replica_devices, _restore_grain_iterator, load_params_from_path, _load_full_state_from_path -from MaxText.multihost_dataloading import MultiHostDataLoadIterator +from MaxText.checkpointing import ( + _load_full_state_from_path, + _replica_devices, + _restore_grain_iterator, + load_params_from_path, +) from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator +from MaxText.multihost_dataloading import MultiHostDataLoadIterator Composite = ocp.args.Composite EmergencyCheckpointManager = emergency_checkpoint_manager.CheckpointManager @@ -41,7 +44,7 @@ def load_state_if_possible( checkpoint_conversion_fn=None, source_checkpoint_layout="orbax", expansion_factor_real_data: int = -1, - ): +): """Loads TrainState as possible from the inputs. Args: @@ -71,91 +74,95 @@ def load_state_if_possible( if checkpoint_manager is not None: max_logging.log("checkpoint manager exists so trying to load this run's existing checkpoint") - step = checkpoint_manager.latest_step() if step < 0 else step - if step is not None: - max_logging.log(f"restoring from this run's directory step {step}") - - def map_to_pspec(data): - if not enable_single_replica_ckpt_restoring: - return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) - pspec = data.sharding.spec - mesh = data.sharding.mesh - replica_axis_index = 0 - replica_devices = _replica_devices(mesh.devices, replica_axis_index) - replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) - single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) - - return ocp.type_handlers.SingleReplicaArrayRestoreArgs( - sharding=jax.sharding.NamedSharding(mesh, pspec), - single_replica_sharding=single_replica_sharding, - global_shape=data.shape, - dtype=data.dtype, - ) - - # Cache the original ArrayHandler before potentially overriding it. - # This is the same handler used when enable_single_replica_ckpt_restoring=False. - original_array_handler = ocp.type_handlers.get_type_handler(jax.Array) - - # Register SingleReplicaArrayHandler globally for restore (if enabled) - if enable_single_replica_ckpt_restoring: - single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler( - replica_axis_index=0, - broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit - ) - ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True) - - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + step = checkpoint_manager.latest_step() if step < 0 else step + if step is not None: + max_logging.log(f"restoring from this run's directory step {step}") + + def map_to_pspec(data): + if not enable_single_replica_ckpt_restoring: + return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) + pspec = data.sharding.spec + mesh = data.sharding.mesh + replica_axis_index = 0 + replica_devices = _replica_devices(mesh.devices, replica_axis_index) + replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + + return ocp.type_handlers.SingleReplicaArrayRestoreArgs( + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + global_shape=data.shape, + dtype=data.dtype, + ) - def _restore_original_array_handler(): - """Restore the original ArrayHandler after SingleReplicaArrayHandler restore. + # Cache the original ArrayHandler before potentially overriding it. + # This is the same handler used when enable_single_replica_ckpt_restoring=False. + original_array_handler = ocp.type_handlers.get_type_handler(jax.Array) - This is critical because SingleReplicaArrayHandler is designed for restore only. - Using it for saves will cause missing array_metadatas files and checkpoint failures. - We restore the EXACT handler that was in place before, not a new instance. - """ + # Register SingleReplicaArrayHandler globally for restore (if enabled) if enable_single_replica_ckpt_restoring: - max_logging.log("Restoring original ArrayHandler after SingleReplicaArrayHandler restore...") - # Re-register the original handler that was cached before the override - ocp.type_handlers.register_type_handler(jax.Array, original_array_handler, override=True) - max_logging.log("Original ArrayHandler restored successfully.") - - match (checkpoint_manager, dataset_type, data_iterator): - # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager - # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and - # 'data_iterator' can be any value and aren't used in this pattern. - case (checkpoint_manager, _, _) if isinstance( - checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) - ): - result = ( - checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, - None, + single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit ) - _restore_original_array_handler() - return result - # Case 2: Matches if dataset type is "grain" and the data iterator is not a - # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator - case ( - checkpoint_manager, - dataset_type, - data_iterator, - ) if ( - dataset_type == "grain" - and data_iterator - and not isinstance(data_iterator, PlaceHolderDataIterator) - and (checkpoint_manager.directory / str(step) / "iter").exists() - ): - result = _restore_grain_iterator( - checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data - ) - _restore_original_array_handler() - return result - # Case 3: Default/Fallback case. - # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. - case _: - result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) - _restore_original_array_handler() - return result + ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True) + + restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + checkpoint_args = ocp.args.PyTreeRestore( + item=abstract_unboxed_pre_state, restore_args=restore_args + ) + + def _restore_original_array_handler(): + """Restore the original ArrayHandler after SingleReplicaArrayHandler restore. + + This is critical because SingleReplicaArrayHandler is designed for restore only. + Using it for saves will cause missing array_metadatas files and checkpoint failures. + We restore the EXACT handler that was in place before, not a new instance. + """ + if enable_single_replica_ckpt_restoring: + max_logging.log( + "Restoring original ArrayHandler after SingleReplicaArrayHandler restore..." + ) + # Re-register the original handler that was cached before the override + ocp.type_handlers.register_type_handler(jax.Array, original_array_handler, override=True) + max_logging.log("Original ArrayHandler restored successfully.") + + match (checkpoint_manager, dataset_type, data_iterator): + # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager + # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and + # 'data_iterator' can be any value and aren't used in this pattern. + case (checkpoint_manager, _, _) if isinstance( + checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) + ): + result = ( + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, + None, + ) + _restore_original_array_handler() + return result + # Case 2: Matches if dataset type is "grain" and the data iterator is not a + # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator + case ( + checkpoint_manager, + dataset_type, + data_iterator, + ) if ( + dataset_type == "grain" + and data_iterator + and not isinstance(data_iterator, PlaceHolderDataIterator) + and (checkpoint_manager.directory / str(step) / "iter").exists() + ): + result = _restore_grain_iterator( + checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data + ) + _restore_original_array_handler() + return result + # Case 3: Default/Fallback case. + # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. + case _: + result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) + _restore_original_array_handler() + return result if load_parameters_from_path != "": restored_params = load_params_from_path( @@ -219,7 +226,7 @@ def create_orbax_checkpoint_manager( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, - max_to_keep = max_to_keep, + max_to_keep=max_to_keep, ), logger=orbax_logger, ) diff --git a/primus/backends/maxtext/configs/types.py b/primus/backends/maxtext/configs/types.py index 99f905912..326ba7eaf 100644 --- a/primus/backends/maxtext/configs/types.py +++ b/primus/backends/maxtext/configs/types.py @@ -5,95 +5,77 @@ # See LICENSE for license information. ############################################################################### import os - -from pydantic import BaseModel -from pydantic.fields import Field -from pydantic import model_validator -from pydantic import ConfigDict from typing import Any -from MaxText.configs.types import ( - MoEGeneral, - DevelopmentAndDebugging, - MaxTextConfig, - # Run and Checkpointing - RunInfo, +from MaxText.configs.types import ( # Run and Checkpointing; Data Types and Quantization; Core Model Architecture; Attention Mechanisms; Mixture of Experts; Parallelism and Layout; Training, Optimization, and Fine-Tuning; Reinforcement Learning; Positional Embeddings; Dataset Loading and Tokenization; Inference; Development and Debugging; Metrics and Monitoring; Multimodal; Derived + AOT, + GRPO, + MTP, + VLLM, + AdamW, + Attention, Checkpointing, - OrbaxStorage, - EmergencyCheckpointing, - # Data Types and Quantization + DatasetGeneral, DataTypes, - Quantization, - # Core Model Architecture - ModelArchitecture, - MTP, + DcnParallelism, + Debug, + Decoding, + DeepSeekMoE, + DerivedValues, + DevelopmentAndDebugging, + EmergencyCheckpointing, + FineTuning, + GcpMonitoring, + Goodput, + GrainDataset, + HardwareAndMesh, + HfDataset, + HloDump, + IciParallelism, + InferenceBenchmark, + InferenceGeneral, + InferenceLayout, + InferenceServer, + LayoutAndSharding, + Llama4Attention, Logits, - # Attention Mechanisms - Attention, + MaxTextConfig, + Metrics, MlaAttention, MoBa, - Llama4Attention, - SplashAttention, - PagedAttention, - # Mixture of Experts + ModelArchitecture, + MoEGeneral, MoEKernels, - DeepSeekMoE, - Qwen3Next, - # Parallelism and Layout - HardwareAndMesh, - LayoutAndSharding, - DcnParallelism, - IciParallelism, + MultimodalGeneral, + Optimizer, + OrbaxStorage, + PagedAttention, PipelineParallelism, - # Training, Optimization, and Fine-Tuning + PositionalEmbedding, + PrefixCaching, + Profiling, + Quantization, + Qwen3Next, RematAndOffload, - TrainingLoop, - Optimizer, - AdamW, - FineTuning, - # Reinforcement Learning - RLHardware, - VLLM, - GRPO, + Reward, RLDataset, RLEvaluation, - Reward, - SpecialTokens, - # Positional Embeddings - PositionalEmbedding, + RLHardware, Rope, - YarnRope, - # Dataset Loading and Tokenization - DatasetGeneral, - TfdsDataset, - HfDataset, - GrainDataset, - Tokenizer, - # Inference - InferenceGeneral, - Decoding, - InferenceLayout, - InferenceServer, - InferenceBenchmark, - PrefixCaching, - # Development and Debugging - AOT, - Profiling, - HloDump, + RunInfo, + SpecialTokens, + SplashAttention, StackTrace, - # Metrics and Monitoring - Metrics, - Goodput, - GcpMonitoring, Tensorboard, - # Multimodal - MultimodalGeneral, - VisionTower, + TfdsDataset, + Tokenizer, + TrainingLoop, VisionProjector, - # Derived - DerivedValues, - Debug, + VisionTower, + YarnRope, ) +from pydantic import BaseModel, ConfigDict, model_validator +from pydantic.fields import Field class PrimusMoEGeneral(MoEGeneral): @@ -102,7 +84,8 @@ class PrimusMoEGeneral(MoEGeneral): class PrimusDevelopmentAndDebugging(DevelopmentAndDebugging): jax_distributed_heartbeat_timeout_seconds: int = Field( - 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." + 100, + description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.", ) @@ -114,7 +97,9 @@ class PrimusTurboConfig(BaseModel): class PrimusWandbConfig(BaseModel): enable_wandb: bool = Field(False, description="Whether to enable WandB.") wandb_project: None | str = Field(None, description="The name of the WandB project.") - wandb_exp_name: None | str = Field(None, description="The name of the WandB experiment, derived from the run_name if not set.") + wandb_exp_name: None | str = Field( + None, description="The name of the WandB experiment, derived from the run_name if not set." + ) wandb_save_dir: None | str = Field(None, description="The directory to save the WandB logs.") @@ -203,16 +188,16 @@ class PrimusMaxTextConfig( ): """ The main configuration object for Primus MaxText. - + This class extends MaxTextConfig with Primus-specific configurations: - Replaces MoEGeneral with PrimusMoEGeneral (adds expert_balance) - Replaces DevelopmentAndDebugging with PrimusDevelopmentAndDebugging (adds jax_distributed_heartbeat_timeout_seconds) - Adds PrimusTurboConfig (Primus Turbo optimizations) - Adds PrimusWandbConfig (WandB integration) - + All other functionality from MaxTextConfig is preserved. """ - + debug: Debug = Field(default_factory=Debug) model_config = ConfigDict(extra="forbid", protected_namespaces=()) diff --git a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py index d776d635e..79c059615 100644 --- a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py +++ b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py @@ -41,7 +41,7 @@ def preprocessing_pipeline( use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 - max_segments_per_seq=1, # max segments per sequence + max_segments_per_seq=1, # max segments per sequence ): """pipeline for preprocessing HF dataset""" assert ( @@ -79,7 +79,11 @@ def preprocessing_pipeline( if len(data_column_names) > 1: combined_column_name = "messages" dataset_features = datasets.Features( - {combined_column_name: [{"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")}]} + { + combined_column_name: [ + {"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")} + ] + } ) dataset = dataset.map( _input_pipeline_utils.combine_columns, @@ -158,10 +162,14 @@ def lists2array(x): operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) else: operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) - operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder)) + operations.append( + grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder) + ) if shift and not use_dpo: - operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1)) + operations.append( + _input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1) + ) # Since HuggingFace IterableDataset does not support access through index # Indexes generated by dummy_index_sampler is not used. @@ -185,11 +193,14 @@ def lists2array(x): read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128), ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + dataloader, global_mesh, generate_padding_batch + ) # Return multi-host jax.Array prep iterator return multihost_gen + def make_hf_train_iterator( config: ml_collections.ConfigDict, global_mesh, diff --git a/primus/backends/maxtext/layers/attention_op.py b/primus/backends/maxtext/layers/attention_op.py index d104caf5c..750b1ea41 100644 --- a/primus/backends/maxtext/layers/attention_op.py +++ b/primus/backends/maxtext/layers/attention_op.py @@ -6,7 +6,12 @@ ############################################################################### import jax.numpy as jnp -from MaxText.common_types import DEFAULT_MASK_VALUE, MODEL_MODE_TRAIN, Array, AttentionType +from MaxText.common_types import ( + DEFAULT_MASK_VALUE, + MODEL_MODE_TRAIN, + Array, + AttentionType, +) from MaxText.layers import nnx_wrappers from MaxText.layers.attention_op import AttentionOp @@ -29,8 +34,12 @@ def cudnn_flash_attention( """ # These imports are only meant to work in a GPU build. # pylint: disable=import-outside-toplevel - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - from transformer_engine.jax.attention import SequenceDescriptor # pytype: disable=import-error + from transformer_engine.jax.attention import ( + SequenceDescriptor, # pytype: disable=import-error + ) + from transformer_engine.jax.flax.transformer import ( + DotProductAttention, # pytype: disable=import-error + ) _, _, _, head_dim = query.shape # pylint: disable=unused-variable @@ -51,10 +60,14 @@ def cudnn_flash_attention( qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) + attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=decoder_segment_ids, segment_pos=None + ) # Create dummy SequenceDescriptor for lazy_init dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=dummy_segment_ids, segment_pos=None + ) max_segments_per_seq = self.config.max_segments_per_seq elif using_context_parallelism or self.config.dataset_type == "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: @@ -65,7 +78,9 @@ def cudnn_flash_attention( mask_type = "causal" else: # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8) + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8 + ) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) @@ -93,10 +108,14 @@ def cudnn_flash_attention( dummy_query_prefill = jnp.zeros( (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype ) - dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype) + dummy_key_prefill = jnp.zeros( + (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype + ) dummy_value_prefill = jnp.zeros( (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype ) - dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, sequence_descriptor=dummy_attn_mask) + dpa_layer.lazy_init( + dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, sequence_descriptor=dummy_attn_mask + ) return dpa_layer(query, key, value, sequence_descriptor=attn_mask) diff --git a/primus/backends/maxtext/layers/gemma.py b/primus/backends/maxtext/layers/gemma.py index f7a1085cc..400d216e1 100644 --- a/primus/backends/maxtext/layers/gemma.py +++ b/primus/backends/maxtext/layers/gemma.py @@ -9,17 +9,15 @@ from flax import nnx from jax.sharding import Mesh - from MaxText import max_utils from MaxText.common_types import Config from MaxText.layers import quantizations from MaxText.layers.attentions import Attention -from MaxText.layers.linears import MlpBlock, Dropout +from MaxText.layers.gemma import GemmaDecoderLayer +from MaxText.layers.linears import Dropout, MlpBlock from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.gemma import GemmaDecoderLayer - class PrimusGemmaDecoderLayer(GemmaDecoderLayer): def __init__( @@ -97,4 +95,4 @@ def __init__( self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/gemma2.py b/primus/backends/maxtext/layers/gemma2.py index d4707059d..72c0fb1b9 100644 --- a/primus/backends/maxtext/layers/gemma2.py +++ b/primus/backends/maxtext/layers/gemma2.py @@ -9,17 +9,15 @@ from flax import nnx from jax.sharding import Mesh - from MaxText import max_utils -from MaxText.common_types import Config, MODEL_MODE_PREFILL +from MaxText.common_types import MODEL_MODE_PREFILL, Config from MaxText.layers import quantizations from MaxText.layers.attentions import Attention, AttentionType -from MaxText.layers.linears import MlpBlock, Dropout +from MaxText.layers.gemma2 import Gemma2DecoderLayer +from MaxText.layers.linears import Dropout, MlpBlock from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.gemma2 import Gemma2DecoderLayer - class PrimusGemma2DecoderLayer(Gemma2DecoderLayer): def __init__( @@ -190,6 +188,10 @@ def __init__( ) if model_mode == MODEL_MODE_PREFILL: - self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + self.activation_axis_names = ( + "activation_batch", + "prefill_activation_norm_length", + "activation_embed", + ) else: self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/llama2.py b/primus/backends/maxtext/layers/llama2.py index 7a3b8e6a5..5225a2e55 100644 --- a/primus/backends/maxtext/layers/llama2.py +++ b/primus/backends/maxtext/layers/llama2.py @@ -9,16 +9,15 @@ from flax import nnx from jax.sharding import Mesh - from MaxText import max_utils -from MaxText.sharding import maybe_shard_with_logical from MaxText.common_types import MODEL_MODE_PREFILL, Config from MaxText.layers import quantizations from MaxText.layers.attentions import Attention -from MaxText.layers.linears import MlpBlock, Dropout +from MaxText.layers.linears import Dropout, MlpBlock +from MaxText.layers.llama2 import LlamaDecoderLayer from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.llama2 import LlamaDecoderLayer +from MaxText.sharding import maybe_shard_with_logical class PrimusLlamaDecoderLayer(LlamaDecoderLayer): @@ -36,7 +35,11 @@ def __init__( self.quant = quant if model_mode == MODEL_MODE_PREFILL: - self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + self.activation_axis_names = ( + "activation_batch", + "prefill_activation_norm_length", + "activation_embed", + ) else: self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") @@ -112,4 +115,4 @@ def __init__( maybe_shard_with_logical, mesh=self.mesh, shard_mode=config.shard_mode, - ) \ No newline at end of file + ) diff --git a/primus/backends/maxtext/layers/mistral.py b/primus/backends/maxtext/layers/mistral.py index 37ae9e68c..64231a969 100644 --- a/primus/backends/maxtext/layers/mistral.py +++ b/primus/backends/maxtext/layers/mistral.py @@ -7,16 +7,14 @@ from flax import nnx from jax.sharding import Mesh - from MaxText import max_utils from MaxText.common_types import Config - from MaxText.layers import quantizations from MaxText.layers.attentions import Attention from MaxText.layers.linears import Dropout, MlpBlock +from MaxText.layers.mistral import MistralDecoderLayer from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.mistral import MistralDecoderLayer class PrimusMistralDecoderLayer(MistralDecoderLayer): diff --git a/primus/backends/maxtext/layers/mixtral.py b/primus/backends/maxtext/layers/mixtral.py index e19fce01f..997bc61c3 100644 --- a/primus/backends/maxtext/layers/mixtral.py +++ b/primus/backends/maxtext/layers/mixtral.py @@ -8,19 +8,15 @@ from flax import linen as nn from flax import nnx from jax.sharding import Mesh - from MaxText import max_utils from MaxText.common_types import Config -from MaxText.layers import initializers -from MaxText.layers import moe -from MaxText.layers import quantizations +from MaxText.layers import initializers, moe, quantizations from MaxText.layers.attentions import Attention from MaxText.layers.linears import Dropout +from MaxText.layers.mixtral import MixtralDecoderLayer from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.mistral import MixtralDecoderLayer - class PrimusMixtralDecoderLayer(MixtralDecoderLayer): @nn.compact @@ -105,4 +101,4 @@ def __init__( self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") \ No newline at end of file + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/moe.py b/primus/backends/maxtext/layers/moe.py index b36396725..49bab61ac 100644 --- a/primus/backends/maxtext/layers/moe.py +++ b/primus/backends/maxtext/layers/moe.py @@ -85,7 +85,9 @@ def dense_matmul( gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts)) ############################################# end #################################################### ########################################## - return super().dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias) + return super().dense_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + ) def sparse_matmul( self, @@ -115,7 +117,15 @@ def sparse_matmul( # Fallback to original implementation if primus_turbo is not available max_logging.log("WARNING: primus_turbo not available, using default ragged_dot in MoE") return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, ) max_logging.log("Using primus_turbo grouped_gemm in MoE") @@ -135,7 +145,15 @@ def _turbo_ragged_dot(*, lhs, rhs, group_sizes, preferred_element_type=None, **k jax.lax.ragged_dot = _turbo_ragged_dot try: return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, ) finally: jax.lax.ragged_dot = _orig_ragged_dot diff --git a/primus/backends/maxtext/max_utils.py b/primus/backends/maxtext/max_utils.py index a5c0b3e8d..e5a048ed5 100644 --- a/primus/backends/maxtext/max_utils.py +++ b/primus/backends/maxtext/max_utils.py @@ -11,10 +11,16 @@ import jax import orbax.checkpoint as ocp -from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import initialization - from MaxText import max_logging -from MaxText.max_utils import _retrieve_jax_init_info, is_gpu_backend, is_cpu_backend, get_coordinator_ip_address +from MaxText.max_utils import ( + _retrieve_jax_init_info, + get_coordinator_ip_address, + is_cpu_backend, + is_gpu_backend, +) +from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import ( + initialization, +) def maybe_initialize_jax_distributed_system(raw_keys): @@ -49,17 +55,19 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_cpu(raw_keys) max_logging.log("Jax distributed system initialized on CPUs!") elif raw_keys["enable_multi_tier_checkpointing"]: - max_logging.log("Attempting to initialize the jax distributed system for multi-tier " "checkpointing...") + max_logging.log( + "Attempting to initialize the jax distributed system for multi-tier " "checkpointing..." + ) initialization.initialize_multi_tier_checkpointing( local_checkpoint_directory=raw_keys["local_checkpoint_directory"], backup_interval_minutes=raw_keys["multi_tier_checkpointing_backup_interval_minutes"], run_name=raw_keys["run_name"], jax_initialization_timeout_seconds=raw_keys["jax_distributed_initialization_timeout"], data_parallelism=raw_keys["mtc_data_parallelism"], - ) + ) max_logging.log("Jax distributed system initialized for multi-tier checkpointing!") elif (raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1) or raw_keys[ - "hardware" + "hardware" ] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") if not raw_keys["enable_emergency_checkpoint"]: @@ -183,10 +191,18 @@ def initialize_wandb_writer(config): return None import wandb + os.makedirs(config.wandb_save_dir, exist_ok=True) - wandb.init(project=config.wandb_project, name=config.wandb_exp_name, dir=config.wandb_save_dir, config=dict(config.get_keys())) - max_logging.log(f"WandB logging enabled: {config.wandb_save_dir=}, {config.wandb_project=}, {config.wandb_exp_name=}") + wandb.init( + project=config.wandb_project, + name=config.wandb_exp_name, + dir=config.wandb_save_dir, + config=dict(config.get_keys()), + ) + max_logging.log( + f"WandB logging enabled: {config.wandb_save_dir=}, {config.wandb_project=}, {config.wandb_exp_name=}" + ) return wandb diff --git a/primus/backends/maxtext/metric_logger.py b/primus/backends/maxtext/metric_logger.py index f249bae49..8a707808d 100644 --- a/primus/backends/maxtext/metric_logger.py +++ b/primus/backends/maxtext/metric_logger.py @@ -10,7 +10,7 @@ import jax import numpy as np from MaxText import max_logging, max_utils, maxtext_utils -from MaxText.metric_logger import MetricLogger +from MaxText.metric_logger import MetadataKey, MetricLogger from .max_utils import close_wandb_writer, initialize_wandb_writer @@ -52,10 +52,12 @@ def write_metrics_to_wandb(self, metrics, step, is_training): def write_setup_info_to_tensorboard(self, params): """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.""" num_model_parameters = max_utils.calculate_num_params_from_pytree(params) - self.metadata["per_device_tflops"], _, _ = maxtext_utils.calculate_tflops_training_per_device( + self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = ( + maxtext_utils.calculate_tflops_training_per_device(self.config) + ) + self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device( self.config ) - self.metadata["per_device_tokens"] = maxtext_utils.calculate_tokens_training_per_device(self.config) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], self.writer) diff --git a/primus/backends/maxtext/train.py b/primus/backends/maxtext/train.py index c134afc9c..563e25962 100644 --- a/primus/backends/maxtext/train.py +++ b/primus/backends/maxtext/train.py @@ -53,195 +53,211 @@ def train_loop(config, recorder, state=None): - """Main Training loop.""" - ( - init_rng, - checkpoint_manager, - state_mesh_shardings, - model, - mesh, - learning_rate_schedule, - data_iterator, - data_loader, - rampup_manager, - eval_data_iterator, - state, - ) = train_utils.setup_train_loop(config, recorder) - - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) - - params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) - - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: - state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() - compiled_stats = compiled.memory_analysis() - max_utils.print_compiled_memory_stats(compiled_stats) - - start_step = get_first_step(state) # this is the start_step for training - prof = profiler.Profiler(config, offset_step=start_step) - metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) - - # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) - - # Synchronize all hosts before entering the training loop. - # Without this barrier, timing variance during initialization (JIT compilation, - # profiler/logger setup, etc.) causes hosts to enter the training loop at different - # times. The first collective operation (data sharding in load_next_batch) then - # times out waiting for straggler hosts, resulting in "collective operation timeout" - # or "stop sending heartbeats" errors. - max_logging.log("====== BARRIER: Synchronizing hosts before training loop ======") - jax.experimental.multihost_utils.sync_global_devices("sync_before_training_loop") - max_logging.log("====== BARRIER PASSED: Starting training loop ======") - - try: - last_step_completion = datetime.datetime.now() - for step in np.arange(start_step, config.steps): - prof.maybe_activate_profiler(step, state) - - with jax.profiler.StepTraceAnnotation("train", step_num=step): - example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # Reshard data from loaded sharding to performant activation sharding - example_batch = sharding.maybe_shard_with_name( - example_batch, - sharding.get_input_data_sharding(config, mesh), - shard_mode=config.shard_mode, - ) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) - with maybe_record_goodput(recorder, GoodputEvent.STEP, step): - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: - state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - jax.block_until_ready(state) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. - gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) + """Main Training loop.""" + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + state, + ) = train_utils.setup_train_loop(config, recorder) - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) - - if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) - if checkpoint_manager is not None: - # in case the last checkpoint_period checkpoint is still in progress - checkpoint_manager.wait_until_finished() - except exceptions.StopTraining as e: - max_logging.log(f"Training stopped: {str(e)}") - finally: - metric_logger.flush_metrics_and_cleanup() - - return state - - -def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]: - """Initialization of hyperparameters and utilities""" - pathwaysutils.initialize() - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - # TF allocates extraneous GPU memory when using TFDS data - # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF - tf.config.set_visible_devices([], "GPU") - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt( + config, state_mesh_shardings ) - # TODO: mazumdera@ : ensure missing mandatory fields in base.yml are filled in in argv, - # or fill in here - config = pyconfig.initialize(argv) - max_utils.print_system_information() - validate_train_config(config) - max_utils.save_device_information(config) - jax.config.update("jax_use_shardy_partitioner", config.shardy) - # update explicit sharding-supported config - if config.shard_mode == ShardMode.EXPLICIT: - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" - vertex_tensorboard_manager = VertexTensorboardManager() - if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): - vertex_tensorboard_manager.configure_vertex_tensorboard(config) - - # Create the Goodput recorder - recorder = create_goodput_recorder(config) - - # Stack traces configurations - debug_config = debug_configuration.DebugConfig( - stack_trace_config=stack_trace_configuration.StackTraceConfig( - collect_stack_trace=config.collect_stack_trace, - stack_trace_to_cloud=config.stack_trace_to_cloud, - stack_trace_interval_seconds=config.stack_trace_interval_seconds, - ) + + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + shaped_batch = maxtext_utils.get_shaped_batch(config) + if config.shard_optimizer_over_data: + state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) + if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded + compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled_stats = compiled.memory_analysis() + max_utils.print_compiled_memory_stats(compiled_stats) + + start_step = get_first_step(state) # this is the start_step for training + prof = profiler.Profiler(config, offset_step=start_step) + metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) + + # Write train config params, num model params, and XLA flags to tensorboard + metric_logger.write_setup_info_to_tensorboard(state.params) + + # Synchronize all hosts before entering the training loop. + # Without this barrier, timing variance during initialization (JIT compilation, + # profiler/logger setup, etc.) causes hosts to enter the training loop at different + # times. The first collective operation (data sharding in load_next_batch) then + # times out waiting for straggler hosts, resulting in "collective operation timeout" + # or "stop sending heartbeats" errors. + max_logging.log("====== BARRIER: Synchronizing hosts before training loop ======") + jax.experimental.multihost_utils.sync_global_devices("sync_before_training_loop") + max_logging.log("====== BARRIER PASSED: Starting training loop ======") + + try: + last_step_completion = datetime.datetime.now() + for step in np.arange(start_step, config.steps): + prof.maybe_activate_profiler(step, state) + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) + # Reshard data from loaded sharding to performant activation sharding + example_batch = sharding.maybe_shard_with_name( + example_batch, + sharding.get_input_data_sharding(config, mesh), + shard_mode=config.shard_mode, + ) + # pylint: disable=not-callable + nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + with maybe_record_goodput(recorder, GoodputEvent.STEP, step): + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + if config.shard_optimizer_over_data: + state = sharding.maybe_shard_with_name( + state, state_mesh_shardings, config.shard_mode + ) + state, metrics = p_train_step(state, example_batch, nextrng) + jax.block_until_ready(state) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint( + checkpoint_manager, state_to_save, config, data_iterator, step + ) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, nextrng) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if ( + metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] + <= config.target_eval_loss + ): + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + + if config.save_checkpoint_on_completion: + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + if checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() + except exceptions.StopTraining as e: + max_logging.log(f"Training stopped: {str(e)}") + finally: + metric_logger.flush_metrics_and_cleanup() + + return state + + +def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, Any, Any]: + """Initialization of hyperparameters and utilities""" + pathwaysutils.initialize() + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + # TF allocates extraneous GPU memory when using TFDS data + # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF + tf.config.set_visible_devices([], "GPU") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + # TODO: mazumdera@ : ensure missing mandatory fields in base.yml are filled in in argv, + # or fill in here + config = pyconfig.initialize(argv, **kwargs) + max_utils.print_system_information() + validate_train_config(config) + max_utils.save_device_information(config) + jax.config.update("jax_use_shardy_partitioner", config.shardy) + # update explicit sharding-supported config + if config.shard_mode == ShardMode.EXPLICIT: + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" + vertex_tensorboard_manager = VertexTensorboardManager() + if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + vertex_tensorboard_manager.configure_vertex_tensorboard(config) + + # Create the Goodput recorder + recorder = create_goodput_recorder(config) + + # Stack traces configurations + debug_config = debug_configuration.DebugConfig( + stack_trace_config=stack_trace_configuration.StackTraceConfig( + collect_stack_trace=config.collect_stack_trace, + stack_trace_to_cloud=config.stack_trace_to_cloud, + stack_trace_interval_seconds=config.stack_trace_interval_seconds, + ) ) - diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) - return config, recorder, diagnostic_config + diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) + return config, recorder, diagnostic_config def run(config, recorder, diagnostic_config): - """Run the job given hyperparameters and utilities""" - with ( - diagnostic.diagnose(diagnostic_config), - maybe_record_goodput(recorder, GoodputEvent.JOB), - max_utils.maybe_get_transformer_engine_context(config), - maybe_monitor_goodput(config), - ): - train_loop(config, recorder) + """Run the job given hyperparameters and utilities""" + try: + with ( + diagnostic.diagnose(diagnostic_config), + maybe_record_goodput(recorder, GoodputEvent.JOB), + max_utils.maybe_get_transformer_engine_context(config), + maybe_monitor_goodput(config), + ): + train_loop(config, recorder) + except Exception as e: + max_logging.log(f"Error in train_loop: {e}") + import traceback + + max_logging.log(f"Traceback: {traceback.format_exc()}") + raise diff --git a/primus/configs/modules/maxtext/trainer_base.yaml b/primus/configs/modules/maxtext/trainer_base.yaml index fdc93daf3..be81b358e 100644 --- a/primus/configs/modules/maxtext/trainer_base.yaml +++ b/primus/configs/modules/maxtext/trainer_base.yaml @@ -324,7 +324,7 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_tensor_transpose', ['tensor_transpose']], + ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], @@ -461,6 +461,7 @@ log_period: 100 # Flushes Tensorboard jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. +jax_distributed_heartbeat_timeout_seconds: 300 # How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. skip_jax_distributed_system: false # If True we will not initialize the jax distributed system. # Currently the jax distributed is needed on cloud TPUs for async checkpointing. diff --git a/primus/modules/trainer/maxtext/pre_trainer.py b/primus/modules/trainer/maxtext/pre_trainer.py index 4b6417352..684316b46 100644 --- a/primus/modules/trainer/maxtext/pre_trainer.py +++ b/primus/modules/trainer/maxtext/pre_trainer.py @@ -122,10 +122,10 @@ def patch_max_utils(self): import MaxText.max_utils as orig_max_utils from primus.backends.maxtext.max_utils import ( - maybe_initialize_jax_distributed_system, - initialize_jax_for_gpu, initialize_jax_for_cpu, + initialize_jax_for_gpu, initialize_jax_for_tpu_with_emergency_checkpointing, + maybe_initialize_jax_distributed_system, print_system_information, save_device_information, ) @@ -133,7 +133,9 @@ def patch_max_utils(self): orig_max_utils.maybe_initialize_jax_distributed_system = maybe_initialize_jax_distributed_system orig_max_utils.initialize_jax_for_gpu = initialize_jax_for_gpu orig_max_utils.initialize_jax_for_cpu = initialize_jax_for_cpu - orig_max_utils.initialize_jax_for_tpu_with_emergency_checkpointing = initialize_jax_for_tpu_with_emergency_checkpointing + orig_max_utils.initialize_jax_for_tpu_with_emergency_checkpointing = ( + initialize_jax_for_tpu_with_emergency_checkpointing + ) orig_max_utils.print_system_information = print_system_information orig_max_utils.save_device_information = save_device_information warning_rank_0("MaxText Pre-Trainer: patch max_utils successfully.") @@ -143,8 +145,8 @@ def patch_checkpoint(self): import MaxText.train_utils as orig_train_utils from primus.backends.maxtext.checkpointing import ( - load_state_if_possible, create_orbax_checkpoint_manager, + load_state_if_possible, ) from primus.backends.maxtext.train_utils import create_training_tools @@ -193,6 +195,7 @@ def patch_input_pipeline(self): def patch_config_types(self): import MaxText.configs.types as orig_config_types + from primus.backends.maxtext.configs.types import PrimusMaxTextConfig orig_config_types.MaxTextConfig = PrimusMaxTextConfig @@ -240,7 +243,7 @@ def patch_decoder_layer(): import MaxText.layers.llama2 as orig_llama2 import MaxText.layers.mistral as orig_mistral import MaxText.layers.mixtral as orig_mixtral - + from primus.backends.maxtext.layers.gemma import PrimusGemmaDecoderLayer from primus.backends.maxtext.layers.gemma2 import PrimusGemma2DecoderLayer from primus.backends.maxtext.layers.llama2 import PrimusLlamaDecoderLayer @@ -254,4 +257,4 @@ def patch_decoder_layer(): orig_mixtral.MixtralDecoderLayer = PrimusMixtralDecoderLayer warning_rank_0("MaxText Pre-Trainer: patch decoder layer successfully.") - patch_decoder_layer() \ No newline at end of file + patch_decoder_layer() diff --git a/primus/pretrain.py b/primus/pretrain.py index 62ba2b0e8..7f4671096 100644 --- a/primus/pretrain.py +++ b/primus/pretrain.py @@ -112,9 +112,10 @@ def setup_backend_path(framework: str, backend_path=None, verbose: bool = True): } mapped_name = fallback_name_map.get(framework, framework) default_path = Path(__file__).resolve().parent.parent / "third_party" / mapped_name - if framework == "maxtext" and os.path.join(default_path, "src").exists(): - default_path = os.path.join(default_path, "src") - candidate_paths.append(default_path) + if framework == "maxtext" and (default_path / "src").exists(): + default_path = default_path / "src" + candidate_paths.insert(0, str(default_path)) + print(f"[Primus] candidate_paths: {candidate_paths}") # Normalize & deduplicate candidate_paths = list(dict.fromkeys(os.path.normpath(os.path.abspath(p)) for p in candidate_paths)) From b96d0b3b3e42638a0ed00930f3a0b50e89892780 Mon Sep 17 00:00:00 2001 From: liyingli Date: Tue, 10 Feb 2026 08:16:08 +0000 Subject: [PATCH 05/24] update jax docker to 26.1 --- .../maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml | 4 ++-- examples/run_local_pretrain.sh | 2 +- examples/run_pretrain.sh | 11 ++++++----- requirements-jax.txt | 11 +---------- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml index c543f9eb1..60bffc889 100644 --- a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml +++ b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml @@ -38,5 +38,5 @@ modules: megablox: false capacity_factor: 1 max_target_length: 4096 - per_device_batch_size: 12 - remat_policy: "minimal" + per_device_batch_size: 11 + remat_policy: "save_dot_with_context_except_mlp" diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index f891922a9..33327c73b 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -42,7 +42,7 @@ EXP=${EXP:-"examples/megatron/exp_pretrain.yaml"} # Default docker image if [ "${BACKEND:-}" = "MaxText" ]; then - DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/mad-private:jax_rocm7.1_jax_0.8.2_ci_e5be0ef_20260131_v3"} + DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v26.1"} else DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} fi diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 442693e9f..a6f13b570 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -102,11 +102,11 @@ export DATA_PATH=${DATA_PATH:-"${PRIMUS_PATH}/data"} export HF_HOME=${HF_HOME:-"${DATA_PATH}/huggingface"} LOG_INFO_RANK0 "Pip installing required packages ..." -# if [ "${BACKEND:-}" != "MaxText" ]; then -# pip install -r "$PRIMUS_PATH/requirements.txt" --quiet -# else -# pip install -r "$PRIMUS_PATH/requirements-jax.txt" --quiet -# fi +if [ "${BACKEND:-}" != "MaxText" ]; then + pip install -r "$PRIMUS_PATH/requirements.txt" --quiet +else + pip install -r "$PRIMUS_PATH/requirements-jax.txt" --quiet +fi FAULT_TOLERANCE_VALUE=$(EXAMPLE_FAULT_TOLERANCE "$@") @@ -273,6 +273,7 @@ if [ "${BACKEND:-}" == "MaxText" ]; then export DUMP_HLO=${DUMP_HLO:-0} export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 + export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" export NVTE_USE_HIPBLASLT=1 if [ "${DUMP_HLO}" = "1" ]; then mkdir -p "${DUMP_HLO_DIR}" diff --git a/requirements-jax.txt b/requirements-jax.txt index aebef58e3..9de6a6b9a 100644 --- a/requirements-jax.txt +++ b/requirements-jax.txt @@ -1,12 +1,3 @@ loguru wandb -expecttest -pre-commit -nltk -matplotlib -markdown2 -weasyprint -tyro -blobfile -mlflow -pyrsmi +pre-commit \ No newline at end of file From 3767dbddc53508fba6771787199b63a5a6b6b348 Mon Sep 17 00:00:00 2001 From: liyingli Date: Tue, 10 Feb 2026 09:01:12 +0000 Subject: [PATCH 06/24] fix ib deps for jax --- examples/run_pretrain.sh | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index a6f13b570..010fd1c7f 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -403,6 +403,21 @@ if [[ "$PATCH_TE_FLASH_ATTN" == "1" ]]; then fi LOG_INFO_RANK0 "" +# -------------------- Install required packages for Jax -------------------- +install_pkgs_for_maxtext() { + LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText ==========" + apt update + apt install autoconf automake libtool pkg-config -y + apt install jq dpkg-dev kmod xz-utils -y + apt install libibverbs-dev ibverbs-utils infiniband-diags -y + apt install rdma-core librdmacm-dev libibverbs-dev libibumad-dev -y + LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText Done ==========" +} + +if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then + install_pkgs_for_maxtext +fi + # ----------------- Rebuild nbxt ----------------- export REBUILD_BNXT=${REBUILD_BNXT:-0} export PATH_TO_BNXT_TAR_PACKAGE=${PATH_TO_BNXT_TAR_PACKAGE} @@ -423,20 +438,6 @@ else LOG_INFO "Skip bnxt rebuild. REBUILD_BNXT=$REBUILD_BNXT, PATH_TO_BNXT_TAR_PACKAGE=$PATH_TO_BNXT_TAR_PACKAGE" fi -# -------------------- Install required packages for Jax -------------------- -install_pkgs_for_maxtext() { - LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText ==========" - apt install iproute2 -y - apt install -y linux-headers-"$(uname -r)" libelf-dev - apt install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev \ - rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev - LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText Done ==========" -} - -if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then - install_pkgs_for_maxtext -fi - # -------------------- HipBLASLt Tuning -------------------- handle_hipblaslt_tuning() { local STAGE=${PRIMUS_HIPBLASLT_TUNING_STAGE:-0} From cbda94703ba9425a5020751cb65559e4be270135 Mon Sep 17 00:00:00 2001 From: liyingli Date: Wed, 11 Feb 2026 07:01:34 +0000 Subject: [PATCH 07/24] fix mixtral config error --- examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml | 2 +- examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml b/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml index c543f9eb1..2bee81be5 100644 --- a/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml +++ b/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml @@ -39,4 +39,4 @@ modules: capacity_factor: 1 max_target_length: 4096 per_device_batch_size: 12 - remat_policy: "minimal" + remat_policy: "save_dot_with_context_except_mlp" diff --git a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml index 60bffc889..f680155ac 100644 --- a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml +++ b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml @@ -39,4 +39,4 @@ modules: capacity_factor: 1 max_target_length: 4096 per_device_batch_size: 11 - remat_policy: "save_dot_with_context_except_mlp" + remat_policy: "minimal" From e0794c93ab6beab70a08467227795f3e34db2132 Mon Sep 17 00:00:00 2001 From: liyingli Date: Thu, 12 Feb 2026 16:25:26 +0000 Subject: [PATCH 08/24] add 405b config file --- .../MI355X/llama3.1_405B-pretrain.yaml | 51 +++++++++++++++++++ .../configs/models/maxtext/llama3.1_405B.yaml | 7 +++ 2 files changed, 58 insertions(+) create mode 100644 examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml create mode 100644 primus/configs/models/maxtext/llama3.1_405B.yaml diff --git a/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml new file mode 100644 index 000000000..f157536f0 --- /dev/null +++ b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml @@ -0,0 +1,51 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:llama3.1_405B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: maxtext + config: pre_trainer.yaml + + # model to run + model: llama3.1_405B.yaml + overrides: + run_name: "llama3.1_405b_training" + base_output_directory: "./output" + steps: 50 + log_period: 10 + profiler: "" + + # data + dataset_type: "synthetic" + hf_access_token: ${HF_TOKEN:""} + + # checkpoint + enable_checkpointing: false + async_checkpointing: false + + # inter-node parallelism strategy + dcn_data_parallelism: 1 + dcn_fsdp_parallelism: -1 + dcn_pipeline_parallelism: 1 + dcn_tensor_parallelism: 1 + dcn_sequence_parallelism: 1 + + # intra-node parallelism strategy + ici_fsdp_parallelism: -1 + ici_data_parallelism: 1 + ici_sequence_parallelism: 1 + ici_tensor_parallelism: 1 + ici_pipeline_parallelism: 1 + + remat_policy: 'full' + optimizer_memory_host_offload: False + param_scan_axis: 1 + megablox: False + + use_iota_embed: True + scan_layers: True + + max_target_length: 8192 + per_device_batch_size: 5 \ No newline at end of file diff --git a/primus/configs/models/maxtext/llama3.1_405B.yaml b/primus/configs/models/maxtext/llama3.1_405B.yaml new file mode 100644 index 000000000..8aa118a0b --- /dev/null +++ b/primus/configs/models/maxtext/llama3.1_405B.yaml @@ -0,0 +1,7 @@ +extends: + - model_base.yaml + +model_name: "llama3.1-405b" +tokenizer_path: "meta-llama/Llama-3.3-70B-Instruct" +attention: "cudnn_flash_te" +use_iota_embed: true \ No newline at end of file From a704ceeec3221854d5c914fbf66cd295997bfeb5 Mon Sep 17 00:00:00 2001 From: liyingli Date: Fri, 13 Feb 2026 00:50:00 +0000 Subject: [PATCH 09/24] update multi-node shell for jax --- examples/run_pretrain.sh | 74 ++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 010fd1c7f..61bc41d31 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -175,32 +175,49 @@ export NCCL_CHECKS_DISABLE=1 # Set InfiniBand GID index for NCCL communication if [ "$USING_AINIC" == "1" ]; then - export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} - export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} - export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} - export NCCL_NET_PLUGIN=librccl-anp.so - LOG_INFO_RANK0 "Using AINIC" - LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" - LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" - LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" - - # unset NCCL_IB_GID_INDEX - export NCCL_IB_GID_INDEX=1 - # export NCCL_IB_ROCE_VERSION_NUM=2 - export NCCL_MAX_P2P_CHANNELS=56 - export NCCL_IB_TC=104 - export NCCL_IB_FIFO_TC=192 - export NET_OPTIONAL_RECV_COMPLETION=1 - export NCCL_IB_USE_INLINE=1 - export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 - export NCCL_GDR_FLUSH_DISABLE=1 - export NCCL_DMABUF_ENABLE=0 - export NCCL_IGNORE_CPU_AFFINITY=1 - export NCCL_IB_QPS_PER_CONNECTION=1 - - export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH - + if [ "${BACKEND:-}" == "MaxText" ]; then + # ------- RCCL/NCCL IB Tuning ------- + export IONIC_LOCKFREE=all + export NCCL_GDR_COPY_ENABLE=1 + export NCCL_GDR_FLUSH_DISABLE=1 + export NCCL_IB_ECE_ENABLE=0 + export NCCL_IB_FIFO_TC=184 + export NCCL_IB_GID_INDEX=1 + export NCCL_IB_PCI_RELAXED_ORDERING=1 + export NCCL_IB_TC=96 + export NCCL_IB_USE_INLINE=1 + export NCCL_IGNORE_CPU_AFFINITY=1 + export NCCL_PXN_DISABLE=0 + export NET_OPTIONAL_RECV_COMPLETION=1 + export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 + export RCCL_LL128_FORCE_ENABLE=1 + else + export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} + export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} + export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} + export NCCL_NET_PLUGIN=librccl-anp.so + + LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" + LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" + LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" + + # unset NCCL_IB_GID_INDEX + export NCCL_IB_GID_INDEX=1 + # export NCCL_IB_ROCE_VERSION_NUM=2 + export NCCL_MAX_P2P_CHANNELS=56 + export NCCL_IB_TC=104 + export NCCL_IB_FIFO_TC=192 + export NET_OPTIONAL_RECV_COMPLETION=1 + export NCCL_IB_USE_INLINE=1 + export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 + export NCCL_GDR_FLUSH_DISABLE=1 + export NCCL_DMABUF_ENABLE=0 + export NCCL_IGNORE_CPU_AFFINITY=1 + export NCCL_IB_QPS_PER_CONNECTION=1 + + export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH + fi else export NCCL_IB_GID_INDEX=3 fi @@ -272,7 +289,12 @@ if [ "${BACKEND:-}" == "MaxText" ]; then export DUMP_HLO_DIR=${DUMP_HLO_DIR:-"${PRIMUS_PATH}/output/xla_dump_hlo"} export DUMP_HLO=${DUMP_HLO:-0} export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 - export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 + if [ $NNODES -gt 1 ]; then + export XLA_PYTHON_CLIENT_MEM_FRACTION=.93 + export JAX_HIP_GRAPH_LOWERING=false + else + export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 + fi export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" export NVTE_USE_HIPBLASLT=1 if [ "${DUMP_HLO}" = "1" ]; then From 10d1a0f13a27d7d98042e0fedbe84b4bfbf1ac88 Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Sat, 14 Feb 2026 03:51:23 +0000 Subject: [PATCH 10/24] Improve slurm launcher logging and make nodelist optional - Add timestamp to log filenames to prevent overwriting across runs - Move tee logging outside the inline script to capture consolidated multi-node output in a single log file - Make --nodelist conditional via NODE_LIST env variable --- examples/run_slurm_pretrain.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh index 04da35a4d..d13e6638f 100755 --- a/examples/run_slurm_pretrain.sh +++ b/examples/run_slurm_pretrain.sh @@ -19,6 +19,7 @@ Optional Environment Variables: NNODES Number of nodes to use [default: 1] MASTER_PORT Master port [default: 12345] LOG_DIR Directory for log output [default: ./output] + NODE_LIST Comma-separated list of nodes for srun --nodelist [default: unset] Example: export DATA_PATH=/mnt/data @@ -35,12 +36,13 @@ export NNODES=${NNODES:-1} SCRIPT_DIR=$(dirname "$(realpath "${BASH_SOURCE[0]}")") export LOG_DIR=${LOG_DIR:-"./output"} -LOG_FILE="${LOG_DIR}/log_slurm_pretrain.txt" +LOG_FILE="${LOG_DIR}/log_slurm_pretrain_$(date +%Y%m%d_%H%M%S).txt" mkdir -p "$LOG_DIR" srun -N "${NNODES}" \ --exclusive \ --export ALL \ + ${NODE_LIST:+--nodelist="${NODE_LIST}"} \ --ntasks-per-node=1 \ --cpus-per-task="${CPUS_PER_TASK:-128}" \ bash -c " @@ -58,5 +60,5 @@ srun -N "${NNODES}" \ export NODE_RANK=\${SLURM_PROCID} export GPUS_PER_NODE=\${SLURM_GPUS_ON_NODE} export REBUILD_PRIMUS_TURBO=\${REBUILD_PRIMUS_TURBO} - bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE} - " bash "$@" + bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" + " bash "$@" 2>&1 | tee "${LOG_FILE}" From f09fc6fb90bc87697dfdeccd0950848c8d9995b4 Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Sat, 14 Feb 2026 19:22:39 +0000 Subject: [PATCH 11/24] corrected XLA_FLAGS and added env var to suppress errors - set TF_CPP_MIN_LOG_LEVEL=2. Without this setting, error occurs at the end when all training steps complete. - XLA_FLAGS is case sensitive. Corrected a few values. --- examples/run_pretrain.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 61bc41d31..b8826ae72 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -295,7 +295,8 @@ if [ "${BACKEND:-}" == "MaxText" ]; then else export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 fi - export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" + export TF_CPP_MIN_LOG_LEVEL=2 # this env var is used to suppress the error logs at the end of training + export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=false" export NVTE_USE_HIPBLASLT=1 if [ "${DUMP_HLO}" = "1" ]; then mkdir -p "${DUMP_HLO_DIR}" From a0aed10645e6f2b2beb3cf5ea6735388d931844f Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Mon, 16 Feb 2026 01:00:50 +0000 Subject: [PATCH 12/24] Make Jax/MaxText work for the unified primus-cli launching command - detect backend framework in `primus-cli-direct.sh`. Install JAX dependencies - If using AINIC (setting USING_AINIC=1), `03_enable_ainic.sh` will run. The `LD_LIBRARY_PATH` is modified to make sure libraries are correctly loaded for JAX/MaxText. - Set XLA_PYTHON_CLIENT_MEM_FRACTION=.93 to avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training - Corrected some XLA_FLAGS. It is case sensitive. Values `true` and `false` do not need to be capitalized. - set TF_CPP_MIN_LOG_LEVEL=2 to suppress the error messages at the end of JAX/MaxText training Here is an example to launch JAX/MaxText traing on two nodes. `./primus-cli --config runner/maxtext-test.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml` --- runner/helpers/hooks/03_enable_ainic.sh | 5 ++- .../hooks/train/pretrain/maxtext/prepare.py | 8 +++- runner/primus-cli-direct.sh | 38 ++++++++++++++++++- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/runner/helpers/hooks/03_enable_ainic.sh b/runner/helpers/hooks/03_enable_ainic.sh index 82b198ad0..5bd162378 100755 --- a/runner/helpers/hooks/03_enable_ainic.sh +++ b/runner/helpers/hooks/03_enable_ainic.sh @@ -41,8 +41,9 @@ NCCL_IB_QPS_PER_CONNECTION="${NCCL_IB_QPS_PER_CONNECTION:-1}" # LD_LIBRARY_PATH: prepend AINIC/RCCL/MPI paths while preserving existing. _ld_base="/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/install/lib" -LD_LIBRARY_PATH="${_ld_base}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" - +# Need to append AINIC/RCCL/MPI paths to the existing LD_LIBRARY_PATH. Otherwise, +# JAX MaxText will not find the appropriate ROCm libraries. +LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}${_ld_base}" LOG_INFO_RANK0 "Using AINIC" LOG_INFO_RANK0 "RCCL_HOME_DIR: ${RCCL_HOME_DIR}" LOG_INFO_RANK0 "ANP_HOME_DIR: ${ANP_HOME_DIR}" diff --git a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py index 917f7cf8f..4bdab8985 100644 --- a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py +++ b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py @@ -237,14 +237,18 @@ def main(): print(f"env.DUMP_HLO_DIR={dump_hlo_dir}") print(f"env.DUMP_HLO={dump_hlo}") print("env.NVTE_ALLOW_NONDETERMINISTIC_ALGO=1") - print("env.XLA_PYTHON_CLIENT_MEM_FRACTION=.97") + # set XLA_PYTHON_CLIENT_MEM_FRACTION to 0.93 + # to avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training + print("env.XLA_PYTHON_CLIENT_MEM_FRACTION=.93") print("env.NVTE_USE_HIPBLASLT=1") - xla_flags = "--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" + xla_flags = "--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=false" if dump_hlo == "1": xla_flags += f" --xla_dump_to={dump_hlo_dir}" log_info(f"XLA HLO dumping enabled, output directory: {dump_hlo_dir}") print(f"env.XLA_FLAGS={xla_flags}") + # set TF_CPP_MIN_LOG_LEVEL=2 to suppress the error messages at the end of JAX/MaxText training + print(f"env.TF_CPP_MIN_LOG_LEVEL=2") # AMD GPU optimizations print("env.HIP_FORCE_DEV_KERNARG=1") diff --git a/runner/primus-cli-direct.sh b/runner/primus-cli-direct.sh index 1dddcb551..4eb6c92c4 100755 --- a/runner/primus-cli-direct.sh +++ b/runner/primus-cli-direct.sh @@ -337,9 +337,45 @@ mkdir -p "$(dirname "${direct_config[log_file]:-}")" ############################################################################### # STEP 5: Install dependencies ############################################################################### +# Detect the backend framework from the experiment YAML (--config in PRIMUS_ARGS) +# so we can install the correct requirements file: +# maxtext -> requirements-jax.txt +# others -> requirements.txt +_detect_framework() { + local cfg_path="" + local args=("${primus_args[@]}") + for ((i=0; i<${#args[@]}; i++)); do + if [[ "${args[$i]}" == "--config" && -n "${args[$((i+1))]:-}" ]]; then + cfg_path="${args[$((i+1))]}" + break + fi + done + if [[ -z "$cfg_path" || ! -f "$cfg_path" ]]; then + echo "" + return + fi + python3 -c " +import yaml, sys +try: + cfg = yaml.safe_load(open('$cfg_path')) + print(cfg.get('modules',{}).get('pre_trainer',{}).get('framework','')) +except Exception: + print('') +" 2>/dev/null +} + +DETECTED_FRAMEWORK="$(_detect_framework)" +LOG_INFO_RANK0 "[direct] Detected framework: ${DETECTED_FRAMEWORK:-unknown}" + # Skip pip install in dry-run mode if [[ "$DRY_RUN_MODE" != "1" ]]; then - pip install -qq -r requirements.txt + if [[ "$DETECTED_FRAMEWORK" == "maxtext" ]]; then + LOG_INFO_RANK0 "[direct] Installing JAX dependencies (requirements-jax.txt)" + pip install -qq -r requirements-jax.txt + else + LOG_INFO_RANK0 "[direct] Installing PyTorch dependencies (requirements.txt)" + pip install -qq -r requirements.txt + fi fi ############################################################################### From 62c59e4630b5f4e68960f2f81415ed198426939a Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Wed, 18 Feb 2026 12:50:17 +0000 Subject: [PATCH 13/24] Updated the apt install package list for Jax/MaxText Problem: when apt install linux-headers-"$(uname -r)", it was resolved to wrong version number on some nodes, and caused "package not found" error. Solution: remove it from the package install list. It does not affect the performance. --- runner/helpers/hooks/train/pretrain/maxtext/prepare.py | 1 - 1 file changed, 1 deletion(-) diff --git a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py index 4bdab8985..0b94cfd55 100644 --- a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py +++ b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py @@ -141,7 +141,6 @@ def install_maxtext_dependencies() -> None: cmd = ( "apt install iproute2 -y && " "apt install -y " - 'linux-headers-"$(uname -r)" ' "libelf-dev " "gcc make libtool autoconf " "librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool " From eac13641b8304922dd880e98d2c49c917d6a1e24 Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Wed, 18 Feb 2026 15:42:10 +0000 Subject: [PATCH 14/24] add /dev/infiniband as default in primus-cli global config file --- runner/.primus.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runner/.primus.yaml b/runner/.primus.yaml index 989a7d6d1..5f1599302 100644 --- a/runner/.primus.yaml +++ b/runner/.primus.yaml @@ -47,7 +47,7 @@ container: device: - "/dev/kfd" - "/dev/dri" - # - "/dev/infiniband" + - "/dev/infiniband" # Linux capabilities (each passed as --cap-add) # NOTE: Do not modify these capabilities - they are required for proper container operation From 08967fd367ca813f13a1f2c98ee1e77e71e56e77 Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Wed, 18 Feb 2026 15:55:13 +0000 Subject: [PATCH 15/24] add primus-cli global config yaml file for AINIC usage --- runner/use_ainic.yaml | 93 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 runner/use_ainic.yaml diff --git a/runner/use_ainic.yaml b/runner/use_ainic.yaml new file mode 100644 index 000000000..83c8b9677 --- /dev/null +++ b/runner/use_ainic.yaml @@ -0,0 +1,93 @@ +# Primus CLI System Default Configuration +# This file provides system-wide default settings for Primus CLI +# +# Priority: CLI args > User config (~/.primus.yaml) > System defaults (this file) + +# Main settings (apply to all modes) +main: + debug: false + dry_run: false + +# Slurm-specific settings +slurm: + debug: false + dry_run: false + # partition: "" + nodes: 1 + gpus_per_node: 8 + time: "4:00:00" + +# Container-specific settings +container: + debug: false + dry_run: false + + # Docker/Podman runtime options + # All keys directly map to CLI arguments (--key value) + options: + # Container image + image: "rocm/jax-training:maxtext-v26.1" + + # Single-value options + ipc: "host" + network: "host" + # cpus: "96" + # memory: "256G" + name: "primus-training" + privileged: "true" + security-opt: "seccomp=unconfined" + group-add: "video" + + # Cumulative options (can be specified multiple times via CLI) + # Device access (each passed as --device) + # NOTE: Do not modify these device paths - they are required for ROCm/GPU access + # /dev/kfd - Kernel Fusion Driver (ROCm core) + # /dev/dri - Direct Rendering Infrastructure (GPU access) + # /dev/infiniband - InfiniBand network device (multi-node communication) + device: + - "/dev/kfd" + - "/dev/dri" + - "/dev/infiniband" + + # Linux capabilities (each passed as --cap-add) + # NOTE: Do not modify these capabilities - they are required for proper container operation + # SYS_PTRACE - Required for debugging and profiling tools + # CAP_SYS_ADMIN - Required for system administration operations + cap-add: + - "SYS_PTRACE" + - "CAP_SYS_ADMIN" + + # Volume mounts (each passed as --volume) + volume: [] + # volume: + # - "/data:/data" + # - "/output:/output" + # - "/workspace/Primus" + # - "/model_weights:/model_weights:ro" + + # Environment variables (each passed as --env KEY=VALUE) + env: + # If using AINIC, set the environment variables for AINIC + # make sure NCCL_IB_GID_INDEX value is set appropriately + - "USING_AINIC=1" + - "NCCL_PXN_DISABLE=0" + - "NCCL_IB_GID_INDEX=1" +# Direct mode settings +direct: + debug: false + + # Distributed training parameters + gpus_per_node: 8 + master_port: 1234 + nnodes: 1 + master_addr: "localhost" + + # Direct mode specific options + run_mode: "torchrun" + script: "primus/cli/main.py" + numa: "auto" + log_file: "" + + # Default patch scripts and env vars + patch: [] + env: [] From 095b2677e8559bad4cbffd97274b8578c72177e8 Mon Sep 17 00:00:00 2001 From: fuyuajin-amd Date: Wed, 18 Feb 2026 20:48:24 +0000 Subject: [PATCH 16/24] Updated Primus-cli user guide 1. added examples for using AINIC in training 2. added more examples for running preflight 3. updated arguments format for benchmark gemm command. The script was changed, but document was not updated. --- docs/cli/PRIMUS-CLI-GUIDE.md | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/cli/PRIMUS-CLI-GUIDE.md b/docs/cli/PRIMUS-CLI-GUIDE.md index 3ec7a2e5b..da629a4aa 100644 --- a/docs/cli/PRIMUS-CLI-GUIDE.md +++ b/docs/cli/PRIMUS-CLI-GUIDE.md @@ -36,7 +36,7 @@ ```bash # Run GEMM benchmark directly on current host -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 ``` --- @@ -65,7 +65,7 @@ Primus CLI supports three execution modes, each suitable for different scenarios ./primus-cli direct -- train pretrain --config config.yaml # GEMM benchmark -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 # Environment check (info only) ./primus-cli direct -- preflight --host --gpu --network @@ -115,7 +115,7 @@ Primus CLI supports three execution modes, each suitable for different scenarios # Set resource limits ./primus-cli container --cpus 32 --memory 256G \ - -- benchmark gemm -M 8192 -N 8192 -K 8192 + -- benchmark gemm --M 8192 --N 8192 --K 8192 # Mount local Primus code for development ./primus-cli container --volume ~/workspace/Primus:/workspace/Primus \ @@ -164,10 +164,18 @@ Primus CLI supports three execution modes, each suitable for different scenarios -- train pretrain --config deepseek_v2.yaml # Run distributed GEMM benchmark -./primus-cli slurm srun -N 2 -- benchmark gemm -M 16384 -N 16384 -K 16384 +./primus-cli slurm srun -N 2 -- benchmark gemm --M 16384 --N 16384 --K 16384 # Multi-node environment check (info only) +# this will generate a fast info report of the host, GPU, and network ./primus-cli slurm srun -N 4 -- preflight --host --gpu --network + +# this will generate a full preflight report of the host, GPU, and network, as well as the performance tests +./primus-cli slurm srun -N 4 -- preflight --report-file-name preflight-report-4N + +# if you are using AINIC in your cluster, use the appropriate configuration file +# for preflight test, set docker image to rocm/primus:v26.1 in the configuration file +./primus-cli --config runner/use_ainic.yaml slurm srun -N 2 -- preflight --report-file-name preflight-report-2N ``` **Suitable for**: @@ -250,6 +258,15 @@ direct: ./primus-cli --config prod.yaml slurm srun -N 8 -- train pretrain ``` +### Using AINIC Configuration File + +If you are using AINIC in your cluster, you can use `runner/use_ainic.yaml` configuration file to configure the AINIC environment. In this file, we have already set the environment variables for AINIC: `USING_AINIC=1`, `NCCL_PXN_DISABLE=0`, and `NCCL_IB_GID_INDEX=1`. You can modify the `NCCL_IB_GID_INDEX` value based on your AINIC settings. Also, you can modify the `image` value to the appropriate Docker image you are using. + +Here is an example of using the AINIC configuration file to run a training job: +```bash +./primus-cli --config runner/use_ainic.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml +``` + ### Configuration Priority **Priority Order** (high to low): @@ -306,13 +323,13 @@ Command-line args > Specified config file > System default config > User config #### GEMM Benchmark ```bash # Single-node GEMM -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 # Run in container -./primus-cli container -- benchmark gemm -M 8192 -N 8192 -K 8192 +./primus-cli container -- benchmark gemm --M 8192 --N 8192 --K 8192 # Multi-node GEMM -./primus-cli slurm srun -N 2 -- benchmark gemm -M 16384 -N 16384 -K 16384 +./primus-cli slurm srun -N 2 -- benchmark gemm --M 16384 --N 16384 --K 16384 ``` #### Other Benchmarks From 87bc2e758a81e5b2bb6b6a85d06197cee6799819 Mon Sep 17 00:00:00 2001 From: Fuyuan Jing <167437074+amd-fuyuajin@users.noreply.github.com> Date: Tue, 24 Feb 2026 22:29:50 -0500 Subject: [PATCH 17/24] Update docs/cli/PRIMUS-CLI-GUIDE.md accept copilot commit suggestion Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/cli/PRIMUS-CLI-GUIDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cli/PRIMUS-CLI-GUIDE.md b/docs/cli/PRIMUS-CLI-GUIDE.md index da629a4aa..2cb8577dc 100644 --- a/docs/cli/PRIMUS-CLI-GUIDE.md +++ b/docs/cli/PRIMUS-CLI-GUIDE.md @@ -260,7 +260,7 @@ direct: ### Using AINIC Configuration File -If you are using AINIC in your cluster, you can use `runner/use_ainic.yaml` configuration file to configure the AINIC environment. In this file, we have already set the environment variables for AINIC: `USING_AINIC=1`, `NCCL_PXN_DISABLE=0`, and `NCCL_IB_GID_INDEX=1`. You can modify the `NCCL_IB_GID_INDEX` value based on your AINIC settings. Also, you can modify the `image` value to the appropriate Docker image you are using. +If you are using AINIC in your cluster, you can use the `runner/use_ainic.yaml` configuration file to configure the AINIC environment. This file includes pre-configured environment variables for AINIC: `USING_AINIC=1`, `NCCL_PXN_DISABLE=0`, and `NCCL_IB_GID_INDEX=1`. You can modify the `NCCL_IB_GID_INDEX` value based on your AINIC settings and update the `image` value to match your Docker image. Here is an example of using the AINIC configuration file to run a training job: ```bash From fdb9c48cc45fc1ceea91247b5ef3b241b9845919 Mon Sep 17 00:00:00 2001 From: liyingli Date: Wed, 25 Feb 2026 07:38:58 +0000 Subject: [PATCH 18/24] fix up by review --- docs/cli/PRIMUS-CLI-GUIDE.md | 2 +- .../MI355X/llama3.1_405B-pretrain.yaml | 2 +- examples/run_pretrain.sh | 2 +- primus/backends/maxtext/configs/types.py | 4 ++-- .../backends/maxtext/layers/attention_op.py | 5 ++++- primus/backends/maxtext/train.py | 21 +++++++------------ .../configs/models/maxtext/llama3.1_405B.yaml | 2 +- requirements-jax.txt | 2 +- runner/helpers/hooks/03_enable_ainic.sh | 2 +- .../hooks/train/pretrain/maxtext/prepare.py | 2 +- 10 files changed, 20 insertions(+), 24 deletions(-) diff --git a/docs/cli/PRIMUS-CLI-GUIDE.md b/docs/cli/PRIMUS-CLI-GUIDE.md index 2cb8577dc..ff954816b 100644 --- a/docs/cli/PRIMUS-CLI-GUIDE.md +++ b/docs/cli/PRIMUS-CLI-GUIDE.md @@ -262,7 +262,7 @@ direct: If you are using AINIC in your cluster, you can use the `runner/use_ainic.yaml` configuration file to configure the AINIC environment. This file includes pre-configured environment variables for AINIC: `USING_AINIC=1`, `NCCL_PXN_DISABLE=0`, and `NCCL_IB_GID_INDEX=1`. You can modify the `NCCL_IB_GID_INDEX` value based on your AINIC settings and update the `image` value to match your Docker image. -Here is an example of using the AINIC configuration file to run a training job: +Here is an example of using the AINIC configuration file to run a training job: ```bash ./primus-cli --config runner/use_ainic.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml ``` diff --git a/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml index f157536f0..0111a5db7 100644 --- a/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml +++ b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml @@ -48,4 +48,4 @@ modules: scan_layers: True max_target_length: 8192 - per_device_batch_size: 5 \ No newline at end of file + per_device_batch_size: 5 diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index b8826ae72..1082f7ab6 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -289,7 +289,7 @@ if [ "${BACKEND:-}" == "MaxText" ]; then export DUMP_HLO_DIR=${DUMP_HLO_DIR:-"${PRIMUS_PATH}/output/xla_dump_hlo"} export DUMP_HLO=${DUMP_HLO:-0} export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 - if [ $NNODES -gt 1 ]; then + if [ "${NNODES}" -gt 1 ]; then export XLA_PYTHON_CLIENT_MEM_FRACTION=.93 export JAX_HIP_GRAPH_LOWERING=false else diff --git a/primus/backends/maxtext/configs/types.py b/primus/backends/maxtext/configs/types.py index 326ba7eaf..a6cf03703 100644 --- a/primus/backends/maxtext/configs/types.py +++ b/primus/backends/maxtext/configs/types.py @@ -221,13 +221,13 @@ def set_derived_and_validate_values(self) -> "PrimusMaxTextConfig": MaxTextConfig.set_derived_and_validate_values(self) # Add any Primus-specific validations here if needed - if self.wandb_save_dir is None or self.wandb_save_dir == "" and self.base_output_directory: + if (self.wandb_save_dir is None or self.wandb_save_dir == "") and self.base_output_directory: self.wandb_save_dir = os.path.join(self.base_output_directory, "wandb") if self.wandb_project is None or self.wandb_project == "": self.wandb_project = os.getenv("WANDB_PROJECT", "Primus-MaxText-Pretrain") - if self.wandb_exp_name is None or self.wandb_exp_name == "" and self.run_name: + if (self.wandb_exp_name is None or self.wandb_exp_name == "") and self.run_name: self.wandb_exp_name = self.run_name if self.enable_wandb and "WANDB_API_KEY" not in os.environ: diff --git a/primus/backends/maxtext/layers/attention_op.py b/primus/backends/maxtext/layers/attention_op.py index 750b1ea41..c10a0a2ad 100644 --- a/primus/backends/maxtext/layers/attention_op.py +++ b/primus/backends/maxtext/layers/attention_op.py @@ -50,6 +50,7 @@ def cudnn_flash_attention( mask_type = "padding_causal" qkv_layout = "BSHD_BSHD_BSHD" # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' max_segments_per_seq = 1 # max number of segments per sequence; for non-packed its 1 + attn_mask_threshold = 0.5 # Handle local sliding window attention if configured if self.attention_type == AttentionType.LOCAL_SLIDING: @@ -82,7 +83,9 @@ def cudnn_flash_attention( (1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8 ) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * attn_mask_threshold), 0, 1).astype( + jnp.uint8 + ) dpa_layer = DotProductAttention( head_dim=head_dim, diff --git a/primus/backends/maxtext/train.py b/primus/backends/maxtext/train.py index 563e25962..ac918c64e 100644 --- a/primus/backends/maxtext/train.py +++ b/primus/backends/maxtext/train.py @@ -247,17 +247,10 @@ def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, def run(config, recorder, diagnostic_config): """Run the job given hyperparameters and utilities""" - try: - with ( - diagnostic.diagnose(diagnostic_config), - maybe_record_goodput(recorder, GoodputEvent.JOB), - max_utils.maybe_get_transformer_engine_context(config), - maybe_monitor_goodput(config), - ): - train_loop(config, recorder) - except Exception as e: - max_logging.log(f"Error in train_loop: {e}") - import traceback - - max_logging.log(f"Traceback: {traceback.format_exc()}") - raise + with ( + diagnostic.diagnose(diagnostic_config), + maybe_record_goodput(recorder, GoodputEvent.JOB), + max_utils.maybe_get_transformer_engine_context(config), + maybe_monitor_goodput(config), + ): + train_loop(config, recorder) diff --git a/primus/configs/models/maxtext/llama3.1_405B.yaml b/primus/configs/models/maxtext/llama3.1_405B.yaml index 8aa118a0b..1519b43e9 100644 --- a/primus/configs/models/maxtext/llama3.1_405B.yaml +++ b/primus/configs/models/maxtext/llama3.1_405B.yaml @@ -4,4 +4,4 @@ extends: model_name: "llama3.1-405b" tokenizer_path: "meta-llama/Llama-3.3-70B-Instruct" attention: "cudnn_flash_te" -use_iota_embed: true \ No newline at end of file +use_iota_embed: true diff --git a/requirements-jax.txt b/requirements-jax.txt index 9de6a6b9a..1a65a0ff1 100644 --- a/requirements-jax.txt +++ b/requirements-jax.txt @@ -1,3 +1,3 @@ loguru wandb -pre-commit \ No newline at end of file +pre-commit diff --git a/runner/helpers/hooks/03_enable_ainic.sh b/runner/helpers/hooks/03_enable_ainic.sh index 5bd162378..f79bbb8e0 100755 --- a/runner/helpers/hooks/03_enable_ainic.sh +++ b/runner/helpers/hooks/03_enable_ainic.sh @@ -41,7 +41,7 @@ NCCL_IB_QPS_PER_CONNECTION="${NCCL_IB_QPS_PER_CONNECTION:-1}" # LD_LIBRARY_PATH: prepend AINIC/RCCL/MPI paths while preserving existing. _ld_base="/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/install/lib" -# Need to append AINIC/RCCL/MPI paths to the existing LD_LIBRARY_PATH. Otherwise, +# Need to append AINIC/RCCL/MPI paths to the existing LD_LIBRARY_PATH. Otherwise, # JAX MaxText will not find the appropriate ROCm libraries. LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}${_ld_base}" LOG_INFO_RANK0 "Using AINIC" diff --git a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py index 0b94cfd55..75bbe8d73 100644 --- a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py +++ b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py @@ -236,7 +236,7 @@ def main(): print(f"env.DUMP_HLO_DIR={dump_hlo_dir}") print(f"env.DUMP_HLO={dump_hlo}") print("env.NVTE_ALLOW_NONDETERMINISTIC_ALGO=1") - # set XLA_PYTHON_CLIENT_MEM_FRACTION to 0.93 + # set XLA_PYTHON_CLIENT_MEM_FRACTION to 0.93 # to avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training print("env.XLA_PYTHON_CLIENT_MEM_FRACTION=.93") print("env.NVTE_USE_HIPBLASLT=1") From 36a162d9402e510b4239405a40580d76821a0f78 Mon Sep 17 00:00:00 2001 From: liyingli Date: Wed, 25 Feb 2026 14:00:24 +0000 Subject: [PATCH 19/24] update cicd for maxtext --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9853068d4..c581310f9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,7 +13,7 @@ env: PRIMUS_TURBO_COMMIT: 5233748e9c5c5795a6484ab31ece47c442d29ec2 # feat(mxfp4): refactor gemm mxfp4 and mxfp8. fuse transpose, hadamard transform and quantization. (#195) ROCSHMEM_COMMIT: 17ff985c026f9f97f85068647e863ab541dd5645 # Update version to 3.2.0 for 7.2.0 rocm release (#351) (#355) BASE_IMAGE: docker.io/rocm/primus:v26.1 - MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v25.9 + MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v26.2 jobs: code-lint: @@ -286,7 +286,7 @@ jobs: echo "✅ [Pip install requirements] started at: $(date)" mkdir -p ${PRIMUS_WORKDIR}/primus-cache python3 -m pip install --upgrade pip setuptools - pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0 + pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1 MAX_JOBS=128 pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --no-build-isolation --no-clean -r requirements.txt end_time=$(date +%s) elapsed=$((end_time - start_time)) From ba5c95ccb218a62870cced7b257c44e3528a3e57 Mon Sep 17 00:00:00 2001 From: liyingli Date: Thu, 26 Feb 2026 14:40:43 +0000 Subject: [PATCH 20/24] disable turbo install to avoid segfault, update cicd for jax and enable model override args set --- .github/workflows/ci.yaml | 9 ++---- primus/modules/trainer/maxtext/pre_trainer.py | 10 +++--- tests/trainer/test_maxtext_trainer.py | 32 +++++++++---------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c581310f9..30c8f4941 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,7 +13,7 @@ env: PRIMUS_TURBO_COMMIT: 5233748e9c5c5795a6484ab31ece47c442d29ec2 # feat(mxfp4): refactor gemm mxfp4 and mxfp8. fuse transpose, hadamard transform and quantization. (#195) ROCSHMEM_COMMIT: 17ff985c026f9f97f85068647e863ab541dd5645 # Update version to 3.2.0 for 7.2.0 rocm release (#351) (#355) BASE_IMAGE: docker.io/rocm/primus:v26.1 - MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v26.2 + MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v26.1 jobs: code-lint: @@ -263,7 +263,7 @@ jobs: env: PRIMUS_WORKDIR: /wekafs/primus-data/primus_safe_ci/jax needs: [code-lint] - runs-on: [primus-lm-cicd-jax-8t8mh] + runs-on: [primus-lm-cicd-jax-v26d1-dl6qc] steps: - run: echo "🎉 Begin Primus-Turbo Checkout." - name: Set commit hash to env @@ -286,19 +286,16 @@ jobs: echo "✅ [Pip install requirements] started at: $(date)" mkdir -p ${PRIMUS_WORKDIR}/primus-cache python3 -m pip install --upgrade pip setuptools - pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1 - MAX_JOBS=128 pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --no-build-isolation --no-clean -r requirements.txt end_time=$(date +%s) elapsed=$((end_time - start_time)) echo "✅ [Pip install requirements] ended at: $(date)" echo "⏱️ [Pip install requirements] Total elapsed time: ${elapsed} seconds" start_time=$(date +%s) echo "✅ [build primus-turbo] started at: $(date)" - PRIMUS_TURBO_FRAMEWORK="JAX" pip3 install --no-build-isolation -e . -v end_time=$(date +%s) elapsed=$((end_time - start_time)) echo "✅ [build primus-turbo] ended at: $(date)" - echo "⏱️ [build primus-turbo] Total elapsed time: ${elapsed} seconds" + echo "⏱️ [build primus-turbo] Torch installation causes segfault, so we skip it and actually not install turbo. Total elapsed time: ${elapsed} seconds" - run: echo "🎉 Begin Primus Unit Test." - uses: actions/checkout@v4 with: diff --git a/primus/modules/trainer/maxtext/pre_trainer.py b/primus/modules/trainer/maxtext/pre_trainer.py index 684316b46..cf2146b61 100644 --- a/primus/modules/trainer/maxtext/pre_trainer.py +++ b/primus/modules/trainer/maxtext/pre_trainer.py @@ -60,7 +60,7 @@ def prepare_model_overrides(self, override_args: Dict[str, Any]): """ Monkey patch maxtext cli args to override model args dynamically. Supports nested overrides like: - {"model": {"num_experts": 16, "base_num_decoder_layers": 4}} + {"override_model": {"num_experts": 16, "base_num_decoder_layers": 4}} All override keys MUST be under the "model" key. """ @@ -71,14 +71,14 @@ def prepare_model_overrides(self, override_args: Dict[str, Any]): warning_rank_0(f"MaxText Pre-Trainer: Applying override_args: {override_args}") - # --- Step 1. Flatten any nested dict under 'model' + # --- Step 1. Flatten any nested dict under 'override_model' flat_overrides = {} for k, v in override_args.items(): - if k != "model": - raise ValueError(f"Only the 'model' key is supported for overrides, found: {k}") + if k != "override_model": + raise ValueError(f"Only the 'override_model' key is supported for overrides, found: {k}") if not isinstance(v, dict): raise ValueError( - f"MaxText Pre-Trainer: The value for 'model' must be a dict, got {type(v).__name__}." + f"MaxText Pre-Trainer: The value for 'override_model' must be a dict, got {type(v).__name__}." ) for subk, subv in v.items(): if isinstance(subv, dict): diff --git a/tests/trainer/test_maxtext_trainer.py b/tests/trainer/test_maxtext_trainer.py index 6e04344e0..945ea6aef 100644 --- a/tests/trainer/test_maxtext_trainer.py +++ b/tests/trainer/test_maxtext_trainer.py @@ -88,7 +88,7 @@ def test_llama3_8B_BF16(self): "llama3_8B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -101,7 +101,7 @@ def test_llama3_8B_FP8(self): "llama3_8B-FP8", exp_path="examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -117,7 +117,7 @@ def test_llama3_70B_BF16(self): "llama3_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -131,7 +131,7 @@ def test_llama3_70B_FP8(self): "llama3_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -147,7 +147,7 @@ def test_llama3_3_70B_BF16(self): "llama3_3_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -161,7 +161,7 @@ def test_llama3_3_70B_FP8(self): "llama3_3_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -177,7 +177,7 @@ def test_llama2_7B_BF16(self): "llama2_7B-BF16", exp_path="examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -191,7 +191,7 @@ def test_llama2_7B_FP8(self): "llama2_7B-FP8", exp_path="examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -207,7 +207,7 @@ def test_llama2_70B_BF16(self): "llama2_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -221,7 +221,7 @@ def test_llama2_70B_FP8(self): "llama2_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -237,7 +237,7 @@ def test_mixtral_8x7B_BF16(self): "mixtral_8x7B-BF16", exp_path="examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -251,7 +251,7 @@ def test_mixtral_8x7B_FP8(self): "mixtral_8x7B-FP8", exp_path="examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -267,7 +267,7 @@ def test_grok1_BF16(self): "grok1-BF16", exp_path="examples/maxtext/configs/MI300X/grok1-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -281,7 +281,7 @@ def test_grok1_FP8(self): "grok1-FP8", exp_path="examples/maxtext/configs/MI300X/grok1-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -297,7 +297,7 @@ def test_dpsk_v2_16B_BF16(self): "dpsk_v2_16B-BF16", exp_path="examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -311,7 +311,7 @@ def test_dpsk_v2_16B_FP8(self): "dpsk_v2_16B-FP8", exp_path="examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", From 8357b2fbda1711874f15b3fb18df36b524b56d75 Mon Sep 17 00:00:00 2001 From: liyingli Date: Fri, 27 Feb 2026 08:49:30 +0000 Subject: [PATCH 21/24] unify NCCL_IB_TC and NCCL_IB_FIFO_TC for maxtext and torch --- examples/run_local_pretrain.sh | 7 ++ examples/run_pretrain.sh | 38 ++++----- examples/scripts/detect_nccl_ib_tc.sh | 110 ++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 18 deletions(-) create mode 100644 examples/scripts/detect_nccl_ib_tc.sh diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 33327c73b..d02da3f63 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -125,6 +125,13 @@ if [ "$USING_AINIC" == "1" ]; then ENV_ARGS+=("--env" "ANP_HOME_DIR") ENV_ARGS+=("--env" "MPI_HOME_DIR") + TC_RESULTS=$(bash "${PRIMUS_PATH}/examples/scripts/detect_nccl_ib_tc.sh") + if [ -z "$TC_RESULTS" ]; then + echo "TC_RESULTS: $TC_RESULTS" + ENV_ARGS+=("--env" "TC_RESULTS") + else + echo "Failed to detect NCCL_IB_TC and NCCL_IB_FIFO_TC" + fi # VOLUME_ARGS+=(-v /mnt/shared:/mnt/shared) # VOLUME_ARGS+=(-v /etc/libibverbs.d/:/etc/libibverbs.d:ro) # VOLUME_ARGS+=(-v /usr/lib/x86_64-linux-gnu/libibverbs/:/usr/lib/x86_64-linux-gnu/libibverbs/:ro) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 1082f7ab6..b02ae0910 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -176,21 +176,33 @@ export NCCL_CHECKS_DISABLE=1 # Set InfiniBand GID index for NCCL communication if [ "$USING_AINIC" == "1" ]; then LOG_INFO_RANK0 "Using AINIC" + # unset NCCL_IB_GID_INDEX + export NCCL_IB_GID_INDEX=1 + # export NCCL_IB_ROCE_VERSION_NUM=2 + if [ -z "${TC_RESULTS:-}" ]; then + export NCCL_IB_TC=${NCCL_IB_TC:-104} + export NCCL_IB_FIFO_TC=${NCCL_IB_FIFO_TC:-192} + else + read -r NCCL_IB_TC NCCL_IB_FIFO_TC <<< "$TC_RESULTS" + export NCCL_IB_TC + export NCCL_IB_FIFO_TC + fi + export NET_OPTIONAL_RECV_COMPLETION=1 + export NCCL_IB_USE_INLINE=1 + export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 + export NCCL_GDR_FLUSH_DISABLE=1 + export NCCL_IGNORE_CPU_AFFINITY=1 + LOG_INFO_RANK0 "NCCL_IB_TC: $NCCL_IB_TC" + LOG_INFO_RANK0 "NCCL_IB_FIFO_TC: $NCCL_IB_FIFO_TC" + if [ "${BACKEND:-}" == "MaxText" ]; then # ------- RCCL/NCCL IB Tuning ------- export IONIC_LOCKFREE=all export NCCL_GDR_COPY_ENABLE=1 - export NCCL_GDR_FLUSH_DISABLE=1 export NCCL_IB_ECE_ENABLE=0 - export NCCL_IB_FIFO_TC=184 - export NCCL_IB_GID_INDEX=1 export NCCL_IB_PCI_RELAXED_ORDERING=1 - export NCCL_IB_TC=96 - export NCCL_IB_USE_INLINE=1 - export NCCL_IGNORE_CPU_AFFINITY=1 + export NCCL_PXN_DISABLE=0 - export NET_OPTIONAL_RECV_COMPLETION=1 - export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 export RCCL_LL128_FORCE_ENABLE=1 else export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} @@ -202,18 +214,8 @@ if [ "$USING_AINIC" == "1" ]; then LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" - # unset NCCL_IB_GID_INDEX - export NCCL_IB_GID_INDEX=1 - # export NCCL_IB_ROCE_VERSION_NUM=2 export NCCL_MAX_P2P_CHANNELS=56 - export NCCL_IB_TC=104 - export NCCL_IB_FIFO_TC=192 - export NET_OPTIONAL_RECV_COMPLETION=1 - export NCCL_IB_USE_INLINE=1 - export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 - export NCCL_GDR_FLUSH_DISABLE=1 export NCCL_DMABUF_ENABLE=0 - export NCCL_IGNORE_CPU_AFFINITY=1 export NCCL_IB_QPS_PER_CONNECTION=1 export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH diff --git a/examples/scripts/detect_nccl_ib_tc.sh b/examples/scripts/detect_nccl_ib_tc.sh new file mode 100644 index 000000000..e24f0cf35 --- /dev/null +++ b/examples/scripts/detect_nccl_ib_tc.sh @@ -0,0 +1,110 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +# +# Auto-detect correct NCCL_IB_TC and NCCL_IB_FIFO_TC for Pensando AINIC clusters. +# Reads QoS DSCP-to-priority mapping from nicctl and finds the PFC-protected DSCP. +# +# Usage: +# source detect_nccl_ib_tc.sh # sets NCCL_IB_TC and NCCL_IB_FIFO_TC +# eval $(./detect_nccl_ib_tc.sh) # alternative: export from subshell + +set -euo pipefail + +is_pensando() { + local ib_dev="" + + for dev in /sys/class/infiniband/*; do + [ -e "$dev" ] || continue + ib_dev=$(basename "$dev") + break + done + [ -z "$ib_dev" ] && return 1 + + if echo "$ib_dev" | grep -qi "ionic"; then + return 0 + fi + + local ca_type + ca_type=$(ibstat "$ib_dev" 2>/dev/null | grep "CA type:" | head -1 || true) + echo "$ca_type" | grep -qi "Pensando" +} + +detect_pensando_tc() { + if ! command -v nicctl &>/dev/null; then + echo "WARN: nicctl not found, using known Pensando defaults" >&2 + echo "104 192" + return + fi + + local qos_output + qos_output=$(nicctl show qos 2>/dev/null) || { + echo "WARN: nicctl show qos failed, using defaults" >&2 + echo "104 192" + return + } + + local pfc_prio + pfc_prio=$(echo "$qos_output" | grep "PFC no-drop priorities" | head -1 | awk '{print $NF}') + + if [ -z "$pfc_prio" ]; then + echo "WARN: Could not determine PFC priority, using defaults" >&2 + echo "104 192" + return + fi + + # nicctl output lines look like: + # DSCP : 26 ==> priority : 3 + # DSCP bitmap : 0x0000000004000000 ==> priority : 3 + # DSCP : 0-25, 27-47, 49-63 ==> priority : 0 + # We want the single DSCP (not bitmap, not range) that maps to each priority. + + # Helper: extract single-value DSCP for a given priority + extract_dscp_for_priority() { + echo "$qos_output" \ + | grep -v "bitmap" \ + | grep "DSCP" \ + | grep "==> priority : ${1}$" \ + | head -1 \ + | sed 's/.*DSCP[^:]*: *//' \ + | sed 's/ *==> .*//' \ + | tr -d ' ' + } + + # NCCL_IB_TC: use DSCP that maps to PFC-protected (no-drop) priority + local data_dscp + data_dscp=$(extract_dscp_for_priority "$pfc_prio") + + if ! echo "$data_dscp" | grep -qE '^[0-9]+$'; then + echo "WARN: Could not parse DSCP for PFC priority $pfc_prio, using defaults" >&2 + echo "104 192" + return + fi + + # NCCL_IB_FIFO_TC: use DSCP that maps to the strict-priority queue + # (scheduling output: priority N has "strict" type) + local strict_prio + strict_prio=$(echo "$qos_output" | grep -i "strict" | head -1 | awk '{print $1}') + local fifo_dscp="" + if [ -n "$strict_prio" ] && echo "$strict_prio" | grep -qE '^[0-9]+$'; then + fifo_dscp=$(extract_dscp_for_priority "$strict_prio") + fi + + if ! echo "$fifo_dscp" | grep -qE '^[0-9]+$'; then + echo "WARN: Could not find strict-priority DSCP, using same as data" >&2 + fifo_dscp="$data_dscp" + fi + + echo "$((data_dscp * 4)) $((fifo_dscp * 4))" +} + +if ! is_pensando; then + echo "# Not a Pensando AINIC cluster, no NCCL_IB_TC override needed" >&2 + exit 0 +fi + +result=$(detect_pensando_tc) +echo "$result" From 694420e5aed9e92ec190801c0be87ff32d453548 Mon Sep 17 00:00:00 2001 From: liyingli Date: Tue, 3 Mar 2026 09:28:38 +0000 Subject: [PATCH 22/24] update jax ainic image and add detect_nccl_ib_tc for cli and remove detect_framework in step 5 for cli --- .github/workflows/ci.yaml | 19 +++ .github/workflows/docker/Dockerfile_jax.ainic | 40 ++++++ examples/run_pretrain.sh | 37 +++-- runner/helpers/hooks/03_detect_nccl_ib_tc.sh | 133 ++++++++++++++++++ runner/helpers/hooks/03_enable_ainic.sh | 2 +- runner/primus-cli-direct.sh | 61 ++------ runner/use_ainic.yaml | 1 + 7 files changed, 226 insertions(+), 67 deletions(-) create mode 100644 .github/workflows/docker/Dockerfile_jax.ainic create mode 100755 runner/helpers/hooks/03_detect_nccl_ib_tc.sh diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2d3bc3bad..ee483e48b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -168,6 +168,25 @@ jobs: docker push docker.io/tasimage/primus:${{env.IMAGE_TAG}}-jax docker login -u rocmshared -p ${{ secrets.ROCM_DOCKER_HUB_TOKEN }} + echo "> Build Docker Image with tag: ${{ env.IMAGE_TAG }}-jax-ainic" + start_time=$(date +%s) + mkdir -p $GITHUB_WORKSPACE/.github/workflows/docker/ainic + cp /apps/tas/0_public/primus_docker_ci/ainic/ainic_bundle_1.117.5-a-56.tar.gz $GITHUB_WORKSPACE/.github/workflows/docker/ainic/ || { echo "Error: Failed to copy ainic bundle"; exit 1; } + docker build -f $GITHUB_WORKSPACE/.github/workflows/docker/Dockerfile_jax.ainic \ + --network=host \ + -t tasimage/primus:${{env.IMAGE_TAG}}-jax-ainic \ + --build-arg BASE_IMAGE=${MAXTEXT_BASE_IMAGE} \ + --build-arg AINIC_BUNDLE_PATH=ainic \ + $GITHUB_WORKSPACE/.github/workflows/docker + end_time=$(date +%s) + elapsed=$((end_time - start_time)) + echo "⏱️ [build primus docker-jax-ainic] Total elapsed time: ${elapsed} seconds" + + docker tag tasimage/primus:${{env.IMAGE_TAG}}-jax-ainic docker.io/tasimage/primus:${{env.IMAGE_TAG}}-jax-ainic + docker login -u tasimage -p ${{ secrets.PRIMUS_DOCKER_HUB_TOKEN }} + docker push docker.io/tasimage/primus:${{env.IMAGE_TAG}}-jax-ainic + docker login -u rocmshared -p ${{ secrets.ROCM_DOCKER_HUB_TOKEN }} + # echo "> Docker cleanup local images" # docker rmi tasimage/primus:${{env.IMAGE_TAG}} # docker rmi tasimage/primus:${{env.IMAGE_TAG}}-v25.09-ainic diff --git a/.github/workflows/docker/Dockerfile_jax.ainic b/.github/workflows/docker/Dockerfile_jax.ainic new file mode 100644 index 000000000..085678ae7 --- /dev/null +++ b/.github/workflows/docker/Dockerfile_jax.ainic @@ -0,0 +1,40 @@ +# Base image +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ARG AINIC_BUNDLE_PATH + +# Non-interactive APT +ENV DEBIAN_FRONTEND=noninteractive + + +# --------------------------------------------------------------------------- +# Install build dependencies +# --------------------------------------------------------------------------- +RUN rm /etc/apt/sources.list.d/*radeon* && \ + apt update && \ + apt install initramfs-tools -y + +# --------------------------------------------------------------------------- +# Enviroment variables +# --------------------------------------------------------------------------- +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib +ENV WORKDIR=/workspace + +# =============================== Build AINIC Driver =============================== +# WARNING: Please ensure the following environment variables are correctly set: +# WARNING: 1. PATH: /usr/sbin must be included. +# WARNING: 2. LD_LIBRARY_PATH: /usr/lib must be included. +# WARNING: If these paths are missing, tools and libraries may not function correctly. +# INFO: Installation completed successfully + +COPY ${AINIC_BUNDLE_PATH}/ainic_bundle_1.117.5-a-56.tar.gz ${WORKDIR} +RUN cd ${WORKDIR} && \ + echo "Building ainic bundle... current directory: ${WORKDIR}" && \ + tar zxf ainic_bundle_1.117.5-a-56.tar.gz && \ + cd ainic_bundle_1.117.5-a-56 && \ + tar zxf host_sw_pkg.tar.gz && \ + cd host_sw_pkg && \ + ./install.sh --domain=user -y 2>&1 | tee log_install.txt && \ + cd ${WORKDIR} && \ + apt-get install -y ./amd/ainic/deb-repo/libionic*.deb diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 930a5f2a9..1729cb50c 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -196,6 +196,14 @@ if [ "$USING_AINIC" == "1" ]; then LOG_INFO_RANK0 "NCCL_IB_FIFO_TC: $NCCL_IB_FIFO_TC" if [ "${BACKEND:-}" == "MaxText" ]; then + if ! command -v ibv_devinfo &>/dev/null || ! ibv_devinfo &>/dev/null; then + LOG_ERROR "Error: ibv_devinfo not found. Please upgrade driver or use tasimage image." + exit 1 + fi + + export ANP_HOME_DIR=${ANP_HOME_DIR:-"/workspace/amd-anp"} + export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/workspace/rccl"} + export MPI_HOME_DIR=${MPI_HOME_DIR:-"/ompi-4.1.6/install/"} # ------- RCCL/NCCL IB Tuning ------- export IONIC_LOCKFREE=all export NCCL_GDR_COPY_ENABLE=1 @@ -204,23 +212,12 @@ if [ "$USING_AINIC" == "1" ]; then export NCCL_PXN_DISABLE=0 export RCCL_LL128_FORCE_ENABLE=1 + + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib else export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} - # Check which NCCL net plugin library is present under ${ANP_HOME_DIR}/build and set accordingly - if [ -f "${ANP_HOME_DIR}/build/librccl-anp.so" ]; then - export NCCL_NET_PLUGIN=librccl-anp.so - elif [ -f "${ANP_HOME_DIR}/build/librccl-net.so" ]; then - export NCCL_NET_PLUGIN=librccl-net.so - else - LOG_ERROR "Error: Neither librccl-anp.so nor librccl-net.so found in ${ANP_HOME_DIR}/build." - exit 1 - fi - - LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" - LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" - LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" export NCCL_MAX_P2P_CHANNELS=56 export NCCL_DMABUF_ENABLE=0 @@ -228,6 +225,20 @@ if [ "$USING_AINIC" == "1" ]; then export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH fi + # Check which NCCL net plugin library is present under ${ANP_HOME_DIR}/build and set accordingly + if [ -f "${ANP_HOME_DIR}/build/librccl-anp.so" ]; then + export NCCL_NET_PLUGIN=librccl-anp.so + elif [ -f "${ANP_HOME_DIR}/build/librccl-net.so" ]; then + export NCCL_NET_PLUGIN=librccl-net.so + else + LOG_ERROR "Error: Neither librccl-anp.so nor librccl-net.so found in ${ANP_HOME_DIR}/build." + exit 1 + fi + + LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" + LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" + LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" + LOG_INFO_RANK0 "NCCL_NET_PLUGIN: $NCCL_NET_PLUGIN" else export NCCL_IB_GID_INDEX=3 fi diff --git a/runner/helpers/hooks/03_detect_nccl_ib_tc.sh b/runner/helpers/hooks/03_detect_nccl_ib_tc.sh new file mode 100755 index 000000000..0bcfe88aa --- /dev/null +++ b/runner/helpers/hooks/03_detect_nccl_ib_tc.sh @@ -0,0 +1,133 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +# +# Global hook: auto-detect NCCL_IB_TC and NCCL_IB_FIFO_TC for Pensando AINIC. +# +# Runs BEFORE 03_enable_ainic.sh (sort order: 03_d < 03_e). +# On Pensando clusters this hook queries nicctl to find the actual PFC-protected +# DSCP and sets NCCL_IB_TC / NCCL_IB_FIFO_TC accordingly. +# +# Priority chain: manual env > this hook (hardware detect) > 03_enable_ainic defaults +# +# If the user already exported NCCL_IB_TC / NCCL_IB_FIFO_TC, this hook +# respects them and exits immediately. +# +# On non-Pensando / non-AINIC clusters this hook is a no-op. +# +# This hook emits env.* lines which will be exported by execute_hooks.sh. +############################################################################### + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Guard: only relevant when AINIC is enabled +# --------------------------------------------------------------------------- +if [[ "${USING_AINIC:-0}" != "1" ]]; then + exit 0 +fi + +# --------------------------------------------------------------------------- +# Guard: respect user-specified values (manual > detect) +# --------------------------------------------------------------------------- +if [[ -n "${NCCL_IB_TC:-}" && -n "${NCCL_IB_FIFO_TC:-}" ]]; then + echo "NCCL_IB_TC and NCCL_IB_FIFO_TC already set, skipping auto-detect" >&2 + exit 0 +fi + +# --------------------------------------------------------------------------- +# Detect whether the first IB device is a Pensando/ionic NIC +# --------------------------------------------------------------------------- +_is_pensando() { + local ib_dev="" + for dev in /sys/class/infiniband/*; do + [ -e "$dev" ] || continue + ib_dev=$(basename "$dev") + break + done + [ -z "$ib_dev" ] && return 1 + + # Quick path: device name contains "ionic" + if echo "$ib_dev" | grep -qi "ionic"; then + return 0 + fi + + # Fallback: check ibstat CA type + local ca_type + ca_type=$(ibstat "$ib_dev" 2>/dev/null | grep "CA type:" | head -1 || true) + echo "$ca_type" | grep -qi "Pensando" +} + +if ! _is_pensando; then + # Not a Pensando AINIC cluster — nothing to do. + exit 0 +fi + +# --------------------------------------------------------------------------- +# Ensure nicctl is available (try to install if missing) +# --------------------------------------------------------------------------- +if ! command -v nicctl &>/dev/null; then + echo "nicctl not found, attempting to install..." >&2 + if ! apt-get install -y nicctl &>/dev/null; then + echo "WARN: Failed to install nicctl, keeping current defaults" >&2 + exit 0 + fi +fi + +# Double-check after install attempt +if ! command -v nicctl &>/dev/null; then + echo "WARN: nicctl still not available after install, keeping current defaults" >&2 + exit 0 +fi + +# --------------------------------------------------------------------------- +# Query nicctl for the correct DSCP values +# --------------------------------------------------------------------------- +qos_output=$(nicctl show qos 2>/dev/null) || { + echo "WARN: nicctl show qos failed, keeping current defaults" >&2 + exit 0 +} + +pfc_prio=$(echo "$qos_output" | grep "PFC no-drop priorities" | head -1 | awk '{print $NF}' || true) +if [ -z "$pfc_prio" ]; then + echo "WARN: Could not determine PFC priority, keeping current defaults" >&2 + exit 0 +fi + +# Extract single-value DSCP for a given priority (skip bitmap and range lines) +_extract_dscp_for_priority() { + echo "$qos_output" \ + | grep -v "bitmap" \ + | grep "DSCP" \ + | grep "==> priority : ${1}$" \ + | head -1 \ + | sed 's/.*DSCP[^:]*: *//' \ + | sed 's/ *==> .*//' \ + | tr -d ' ' || true +} + +# NCCL_IB_TC: DSCP mapped to PFC-protected (no-drop) priority × 4 +data_dscp=$(_extract_dscp_for_priority "$pfc_prio") +if ! echo "$data_dscp" | grep -qE '^[0-9]+$'; then + echo "WARN: Could not parse DSCP for PFC priority $pfc_prio, keeping current defaults" >&2 + exit 0 +fi + +# NCCL_IB_FIFO_TC: DSCP mapped to strict-priority queue × 4 +strict_prio=$(echo "$qos_output" | grep -i "strict" | head -1 | awk '{print $1}' || true) +fifo_dscp="" +if [ -n "$strict_prio" ] && echo "$strict_prio" | grep -qE '^[0-9]+$'; then + fifo_dscp=$(_extract_dscp_for_priority "$strict_prio") +fi +if ! echo "$fifo_dscp" | grep -qE '^[0-9]+$'; then + fifo_dscp="$data_dscp" +fi + +ib_tc=$((data_dscp * 4)) +fifo_tc=$((fifo_dscp * 4)) + +echo "env.NCCL_IB_TC=${ib_tc}" +echo "env.NCCL_IB_FIFO_TC=${fifo_tc}" diff --git a/runner/helpers/hooks/03_enable_ainic.sh b/runner/helpers/hooks/03_enable_ainic.sh index f79bbb8e0..c23109858 100755 --- a/runner/helpers/hooks/03_enable_ainic.sh +++ b/runner/helpers/hooks/03_enable_ainic.sh @@ -27,7 +27,7 @@ RCCL_HOME_DIR="${RCCL_HOME_DIR:-/opt/rccl}" MPI_HOME_DIR="${MPI_HOME_DIR:-/opt/ompi-4.1.6}" NCCL_IB_TC="${NCCL_IB_TC:-104}" -NCCL_IB_FIFO_TC="${NCCL_IB_FIFO_TC:-184}" +NCCL_IB_FIFO_TC="${NCCL_IB_FIFO_TC:-192}" NCCL_IB_GID_INDEX="${NCCL_IB_GID_INDEX:-1}" NCCL_IB_ROCE_VERSION_NUM="${NCCL_IB_ROCE_VERSION_NUM:-2}" NCCL_MAX_P2P_CHANNELS="${NCCL_MAX_P2P_CHANNELS:-56}" diff --git a/runner/primus-cli-direct.sh b/runner/primus-cli-direct.sh index 4eb6c92c4..b58a15271 100755 --- a/runner/primus-cli-direct.sh +++ b/runner/primus-cli-direct.sh @@ -333,53 +333,8 @@ if [[ -z "${direct_config[log_file]:-}" ]]; then fi mkdir -p "$(dirname "${direct_config[log_file]:-}")" - -############################################################################### -# STEP 5: Install dependencies -############################################################################### -# Detect the backend framework from the experiment YAML (--config in PRIMUS_ARGS) -# so we can install the correct requirements file: -# maxtext -> requirements-jax.txt -# others -> requirements.txt -_detect_framework() { - local cfg_path="" - local args=("${primus_args[@]}") - for ((i=0; i<${#args[@]}; i++)); do - if [[ "${args[$i]}" == "--config" && -n "${args[$((i+1))]:-}" ]]; then - cfg_path="${args[$((i+1))]}" - break - fi - done - if [[ -z "$cfg_path" || ! -f "$cfg_path" ]]; then - echo "" - return - fi - python3 -c " -import yaml, sys -try: - cfg = yaml.safe_load(open('$cfg_path')) - print(cfg.get('modules',{}).get('pre_trainer',{}).get('framework','')) -except Exception: - print('') -" 2>/dev/null -} - -DETECTED_FRAMEWORK="$(_detect_framework)" -LOG_INFO_RANK0 "[direct] Detected framework: ${DETECTED_FRAMEWORK:-unknown}" - -# Skip pip install in dry-run mode -if [[ "$DRY_RUN_MODE" != "1" ]]; then - if [[ "$DETECTED_FRAMEWORK" == "maxtext" ]]; then - LOG_INFO_RANK0 "[direct] Installing JAX dependencies (requirements-jax.txt)" - pip install -qq -r requirements-jax.txt - else - LOG_INFO_RANK0 "[direct] Installing PyTorch dependencies (requirements.txt)" - pip install -qq -r requirements.txt - fi -fi - ############################################################################### -# STEP 6: Source GPU environment and helper modules +# STEP 5: Source GPU environment and helper modules ############################################################################### # Source primus-env.sh (it will set its own SCRIPT_DIR, which is fine) @@ -387,7 +342,7 @@ fi source "${RUNNER_DIR}/helpers/envs/primus-env.sh" ############################################################################### -# STEP 7: Execute hooks and capture extra arguments / env +# STEP 6: Execute hooks and capture extra arguments / env ############################################################################### # Hooks can return: # - Extra Primus CLI arguments by printing lines: extra.= @@ -409,7 +364,7 @@ if [[ ${#HOOK_EXTRA_PRIMUS_ARGS[@]} -gt 0 ]]; then fi ############################################################################### -# STEP 8: Execute patch scripts +# STEP 7: Execute patch scripts ############################################################################### # Execute patch scripts from config + CLI. # Note: direct_config[patch] is stored as a newline-separated list. @@ -427,7 +382,7 @@ if [[ -n "${direct_config[patch]:-}" ]]; then fi ############################################################################### -# STEP 8.5: Apply extra Primus args from patches (extra.* protocol) +# STEP 7.5: Apply extra Primus args from patches (extra.* protocol) ############################################################################### if [[ ${#PATCH_EXTRA_PRIMUS_ARGS[@]} -gt 0 ]]; then set -- "$@" "${PATCH_EXTRA_PRIMUS_ARGS[@]}" @@ -435,7 +390,7 @@ if [[ ${#PATCH_EXTRA_PRIMUS_ARGS[@]} -gt 0 ]]; then fi ############################################################################### -# STEP 9: Build and export environment variables (highest priority) +# STEP 8: Build and export environment variables (highest priority) ############################################################################### # First, source any env files specified via --env (tracked as env_file). @@ -472,7 +427,7 @@ if [[ -n "${direct_config[env]:-}" ]]; then fi ############################################################################### -# STEP 10: Build launch command +# STEP 9: Build launch command ############################################################################### # Allow RUN_MODE to be overridden by environment variable @@ -523,7 +478,7 @@ elif [[ "$RUN_MODE" == "torchrun" ]]; then fi ############################################################################### -# STEP 11: Display configuration (always) +# STEP 10: Display configuration (always) ############################################################################### if [[ "$DRY_RUN_MODE" == "1" ]]; then print_section "[DRY RUN] Direct Launch Configuration" @@ -581,7 +536,7 @@ fi print_section "" ############################################################################### -# STEP 12: Execute command +# STEP 11: Execute command ############################################################################### # Temporarily allow pipeline to fail so we can capture PIPESTATUS and log it eval "$CMD" diff --git a/runner/use_ainic.yaml b/runner/use_ainic.yaml index 83c8b9677..3d21d776c 100644 --- a/runner/use_ainic.yaml +++ b/runner/use_ainic.yaml @@ -72,6 +72,7 @@ container: - "USING_AINIC=1" - "NCCL_PXN_DISABLE=0" - "NCCL_IB_GID_INDEX=1" + - "NCCL_NET_PLUGIN=librccl-anp.so" # Only applicable for new AINIC images (with librccl-anp.so) # Direct mode settings direct: debug: false From 846b4b35f305614732c17ae461adb6fc6940185f Mon Sep 17 00:00:00 2001 From: liyingli Date: Tue, 3 Mar 2026 09:44:19 +0000 Subject: [PATCH 23/24] update cicd runner for jax --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index de6e503b2..d18fe9d70 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -307,7 +307,7 @@ jobs: env: PRIMUS_WORKDIR: /wekafs/primus-data/primus_safe_ci/jax needs: [code-lint] - runs-on: [primus-lm-cicd-jax-v26d1-dl6qc] + runs-on: [primus-lm-cicd-jax-m42vb] steps: - run: echo "🎉 Begin Primus-Turbo Checkout." - name: Set commit hash to env From 24b2931f6f1db925625611505795490841ddf282 Mon Sep 17 00:00:00 2001 From: liyingli Date: Tue, 3 Mar 2026 09:58:22 +0000 Subject: [PATCH 24/24] update docker file ainic for jax --- .github/workflows/docker/Dockerfile_jax.ainic | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/docker/Dockerfile_jax.ainic b/.github/workflows/docker/Dockerfile_jax.ainic index 085678ae7..680dac6e9 100644 --- a/.github/workflows/docker/Dockerfile_jax.ainic +++ b/.github/workflows/docker/Dockerfile_jax.ainic @@ -35,6 +35,4 @@ RUN cd ${WORKDIR} && \ cd ainic_bundle_1.117.5-a-56 && \ tar zxf host_sw_pkg.tar.gz && \ cd host_sw_pkg && \ - ./install.sh --domain=user -y 2>&1 | tee log_install.txt && \ - cd ${WORKDIR} && \ - apt-get install -y ./amd/ainic/deb-repo/libionic*.deb + ./install.sh --domain=user -y 2>&1 | tee log_install.txt