From f49a7704ad6ec1b863f1c4336e602619bd09afbb Mon Sep 17 00:00:00 2001 From: Vivek Date: Wed, 18 Dec 2024 12:25:53 +0200 Subject: [PATCH 1/6] Add HPU as new device Add yaml for single device LoRA with HPU Update yaml for single device LoRA with HPU Set device for lm_eval Add lm_eval configs for HPU Support distributed training in Gaudi Distributed Finetuning on HPU Keeping original changes Re-Enable profiler activities for XPU --- .../llama3_1/8B_full_single_device_hpu.yaml | 110 ++++++++++++++++ .../llama3_1/8B_lora_single_device_hpu.yaml | 118 ++++++++++++++++++ .../eval_configs_hpu/custom_eval_base.yaml | 37 ++++++ .../eval_configs_hpu/custom_eval_full_ft.yaml | 37 ++++++ .../eval_configs_hpu/custom_eval_lora_ft.yaml | 38 ++++++ recipes/eleuther_eval.py | 3 + recipes/lora_finetune_distributed.py | 44 +++++-- torchtune/training/_distributed.py | 2 + torchtune/training/memory.py | 4 +- torchtune/training/precision.py | 5 +- torchtune/utils/_device.py | 19 ++- 11 files changed, 401 insertions(+), 16 deletions(-) create mode 100644 recipes/configs/llama3_1/8B_full_single_device_hpu.yaml create mode 100644 recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml create mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml create mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml create mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml diff --git a/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml b/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml new file mode 100644 index 0000000000..a6833f97b7 --- /dev/null +++ b/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml @@ -0,0 +1,110 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir 
/tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +output_dir: /tmp/torchtune/llama3_1_8B/full_single_device # /tmp may be deleted by your system. Change it to your preference. + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 512 + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + packed: True # True increases speed +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +optimizer_in_bwd: True # True saves memory. 
Requires gradient_accumulation_steps=1 +compile: True # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: hpu + +# Memory management +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 10 +log_peak_memory_stats: False + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: False + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml b/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml new file mode 100644 index 0000000000..4bdc88636b --- /dev/null +++ b/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml @@ -0,0 +1,118 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device +# +# You can add specific overrides through the command line. 
For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device # /tmp may be deleted by your system. Change it to your preference. + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 512 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: True # True increases speed +seed: null +shuffle: True +batch_size: 1 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + # fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 # Use to increase effective batch size +compile: True # torch.compile the model + loss, True increases speed + decreases 
memory + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 10 +log_peak_memory_stats: False + +# Environment +device: hpu +dtype: bf16 + +# Activations Memory +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: False + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml new file mode 100644 index 0000000000..4e8e034aa4 --- /dev/null +++ b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml @@ -0,0 +1,37 @@ +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/torchtune/llama3_1_8B/full_single_device + model_type: LLAMA3 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 512 + +# Environment +device: hpu +dtype: bf16 +seed: 1234 # It is not 
recommended to change this seed, b/c it matches EleutherAI's default seed +compile: True +# EleutherAI specific eval args +tasks: ["hellaswag"] +limit: null +max_seq_length: 512 +batch_size: 32 +enable_kv_cache: True + +# Quantization specific args +quantizer: null diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml new file mode 100644 index 0000000000..8e63db29eb --- /dev/null +++ b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml @@ -0,0 +1,37 @@ +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/torchtune/llama3_1_8B/full_single_device/base_model/ + checkpoint_files: [ + /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00001-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00002-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00003-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/torchtune/llama3_1_8B/full_single_device + model_type: LLAMA3 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 512 + +# Environment +device: hpu +dtype: bf16 +seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed +compile: True +# EleutherAI specific eval args +tasks: ["hellaswag"] +limit: null +max_seq_length: 512 +batch_size: 32 +enable_kv_cache: True + +# Quantization specific args +quantizer: null diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml new file mode 100644 index 
0000000000..05a1b9d8cd --- /dev/null +++ b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml @@ -0,0 +1,38 @@ +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +Full finetuned +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00001-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00002-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00003-of-00004.safetensors, + /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: output + model_type: LLAMA3 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: 512 + +# Environment +device: hpu +dtype: bf16 +seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed +compile: True +# EleutherAI specific eval args +tasks: ["hellaswag"] +limit: null +max_seq_length: 512 +batch_size: 32 +enable_kv_cache: True + +# Quantization specific args +quantizer: null diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 7be03aba4a..a9d0bdc4d9 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -315,6 +315,9 @@ def __init__( self._batch_size = batch_size self._dtype = dtype self._enable_kv_cache = enable_kv_cache + # Set device explicitly here since HPU is not included in + # `device_list` in `HFLM` class + self._device = torch.device(device) @property def model(self): diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 1ae96347ba..1c2389a4a0 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -498,6 +498,20 @@ def 
_setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + if torch.hpu.is_available(): + # Initialize LoRA params and RoPE buffers (Before FSDP sharding) + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if (isinstance(m, AdapterModule)) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.to_empty(device=lora_device) + m.initialize_parameters() + + if hasattr(m, "rope_init"): + m.rope_init() + # For FSDP sharding fsdp_shard_conditions = [ partial( @@ -522,18 +536,19 @@ def _setup_model( else: lora_missing, lora_unexpected = None, None - # Initialize LoRA params and RoPE buffers - with training.set_default_dtype(self._dtype), self._device: - lora_device = "cpu" if fsdp_cpu_offload else self._device - for m in model.modules(): - if (isinstance(m, AdapterModule)) and not lora_weights_state_dict: - # lora may not be covered in state dict - # if finetune for the 1st time - m.to_empty(device=lora_device) - m.initialize_parameters() + if torch.cuda.is_available(): + # Initialize LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if (isinstance(m, AdapterModule)) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.to_empty(device=lora_device) + m.initialize_parameters() - if hasattr(m, "rope_init"): - m.rope_init() + if hasattr(m, "rope_init"): + m.rope_init() base_missing, base_unexpected = training.load_from_full_model_state_dict( model, @@ -997,6 +1012,13 @@ def recipe_main(cfg: DictConfig) -> None: "Distributed finetune recipe should be run via a distributed launcher." 
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) +<<<<<<< HEAD +======= + if torch.cuda.is_available(): + init_process_group("cuda:nccl,cpu:gloo") + elif torch.hpu.is_available(): + init_process_group("hpu:hccl,cpu:gloo") +>>>>>>> 0b4f8cc1 (Add HPU as new device) if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 4e089e516a..6ffdf69a46 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -173,6 +173,8 @@ def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: tensor = tensor.to(get_device("cuda")) elif dist.get_backend() == "xccl": tensor = tensor.to(get_device("xpu")) + elif dist.get_backend() == "hccl": + tensor = tensor.to(get_device("hpu")) dist.broadcast(tensor, src=src, group=None) return tensor.to(device) else: diff --git a/torchtune/training/memory.py b/torchtune/training/memory.py index 19a8b09b90..cf53d8423d 100644 --- a/torchtune/training/memory.py +++ b/torchtune/training/memory.py @@ -48,7 +48,9 @@ def cleanup_before_training() -> None: Call gc collect, empty device cache, and reset peak memory stats. 
""" gc.collect() - get_torch_device_namespace().empty_cache() + from torchtune.utils._device import is_hpu_available + if not is_hpu_available: + get_torch_device_namespace().empty_cache() get_torch_device_namespace().reset_peak_memory_stats() diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py index d9232cf97f..38096757b2 100644 --- a/torchtune/training/precision.py +++ b/torchtune/training/precision.py @@ -10,7 +10,7 @@ import torch from torchtune.utils import get_logger -from torchtune.utils._device import is_npu_available +from torchtune.utils._device import is_hpu_available, is_npu_available log = get_logger() @@ -69,7 +69,8 @@ def verify_bf16_support() -> bool: mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built() npu_support = is_npu_available and torch.npu.is_bf16_supported() xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported() - return cuda_support or mps_support or npu_support or xpu_support + hpu_support = is_hpu_available and torch.hpu.is_bf16_supported() + return cuda_support or mps_support or npu_support or xpu_support or hpu_support def get_dtype( diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index cf32653b7b..5593298dd9 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -47,6 +47,19 @@ def is_torch_npu_available() -> bool: is_npu_available = is_torch_npu_available() +def is_torch_hpu_available() -> bool: + """Check the availability of HPU""" + try: + import habana_frameworks.torch # noqa: F401 + + return torch.hpu.is_available() + except ImportError: + return False + + +is_hpu_available = is_torch_hpu_available() + + def _get_local_rank() -> Optional[int]: """Function that gets the local rank from the environment. 
@@ -78,7 +91,6 @@ def _setup_device(device: torch.device) -> torch.device: device_type = device_support.device_type device_name = device_support.device_name torch_device = get_torch_device_namespace() - if device.index is None: device = torch.device(type=device_type, index=local_rank) @@ -107,6 +119,8 @@ def _get_device_type_from_env() -> str: device = "cuda" elif is_npu_available: device = "npu" + elif is_hpu_available: + device = "hpu" elif torch.xpu.is_available(): device = "xpu" elif torch.mps.is_available(): @@ -171,7 +185,7 @@ def get_device(device: Optional[str] = None) -> torch.device: if device is None: device = _get_device_type_from_env() device = torch.device(device) - if device.type in ["cuda", "npu", "xpu"]: + if device.type in ["cuda", "npu", "xpu", "hpu"]: device = _setup_device(device) _validate_device_from_env(device) return device @@ -220,6 +234,7 @@ class DeviceSupport(Enum): NPU = ("npu", "NPU", "hccl") XPU = ("xpu", "XPU", "ccl") MPS = ("mps", "MPS", "gloo") + HPU = ("hpu", "HPU", "hccl") def __init__( self, From 97643998eefc819f6ba400ab05b0b6e17c8d8d91 Mon Sep 17 00:00:00 2001 From: Vivek Date: Mon, 5 May 2025 11:38:44 +0300 Subject: [PATCH 2/6] Fix assert at end of training --- .../llama3_1/8B_full_single_device_hpu.yaml | 110 ---------------- .../llama3_1/8B_lora_single_device_hpu.yaml | 118 ------------------ recipes/lora_finetune_distributed.py | 9 +- 3 files changed, 3 insertions(+), 234 deletions(-) delete mode 100644 recipes/configs/llama3_1/8B_full_single_device_hpu.yaml delete mode 100644 recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml diff --git a/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml b/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml deleted file mode 100644 index a6833f97b7..0000000000 --- a/recipes/configs/llama3_1/8B_full_single_device_hpu.yaml +++ /dev/null @@ -1,110 +0,0 @@ -# Config for single device full finetuning in full_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# 
This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# The default config uses an optimizer from bitsandbytes. If you do not have it installed, -# you can install it with -# pip install bitsandbytes -# -# To launch on a single device, run the following command from root: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -output_dir: /tmp/torchtune/llama3_1_8B/full_single_device # /tmp may be deleted by your system. Change it to your preference. - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: 512 - -# Dataset -dataset: - _component_: torchtune.datasets.alpaca_dataset - packed: True # True increases speed -seed: null -shuffle: True - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - checkpoint_files: [ - model-00001-of-00004.safetensors, - model-00002-of-00004.safetensors, - model-00003-of-00004.safetensors, - model-00004-of-00004.safetensors - ] - recipe_checkpoint: null - output_dir: ${output_dir} - model_type: LLAMA3 -resume_from_checkpoint: False - -# Fine-tuning arguments -batch_size: 2 -epochs: 1 -optimizer: - _component_: torch.optim.AdamW - lr: 2e-5 -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -max_steps_per_epoch: null 
-gradient_accumulation_steps: 1 # Use to increase effective batch size -optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 -compile: True # torch.compile the model + loss, True increases speed + decreases memory - -# Training environment -device: hpu - -# Memory management -enable_activation_checkpointing: False # True reduces memory -enable_activation_offloading: False # True reduces memory - -# Reduced precision -dtype: bf16 - -# Logging -metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger - log_dir: ${output_dir}/logs -log_every_n_steps: 10 -log_peak_memory_stats: False - - -# Profiler (disabled) -profiler: - _component_: torchtune.training.setup_torch_profiler - enabled: False - - #Output directory of trace artifacts - output_dir: ${output_dir}/profiling_outputs - - #`torch.profiler.ProfilerActivity` types to trace - cpu: True - cuda: False - - #trace options passed to `torch.profiler.profile` - profile_memory: False - with_stack: False - record_shapes: True - with_flops: False - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 5 - warmup_steps: 3 - active_steps: 2 - num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml b/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml deleted file mode 100644 index 4bdc88636b..0000000000 --- a/recipes/configs/llama3_1/8B_lora_single_device_hpu.yaml +++ /dev/null @@ -1,118 +0,0 @@ -# Config for single device LoRA finetuning in lora_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# To launch on a single device, run the following command from root: -# tune run 
lora_finetune_single_device --config llama3_1/8B_lora_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device # /tmp may be deleted by your system. Change it to your preference. - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.lora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] - apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 8 # higher increases accuracy and memory - lora_alpha: 16 # usually alpha=2*rank - lora_dropout: 0.0 - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: 512 - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - checkpoint_files: [ - model-00001-of-00004.safetensors, - model-00002-of-00004.safetensors, - model-00003-of-00004.safetensors, - model-00004-of-00004.safetensors - ] - recipe_checkpoint: null - output_dir: ${output_dir} - model_type: LLAMA3 -resume_from_checkpoint: False -save_adapter_weights_only: False - -# Dataset and Sampler -dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: True # True increases speed -seed: null -shuffle: True -batch_size: 1 - -# Optimizer and Scheduler -optimizer: - _component_: torch.optim.AdamW - # fused: True - weight_decay: 0.01 - lr: 3e-4 -lr_scheduler: - _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup - num_warmup_steps: 100 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss - -# Training -epochs: 1 -max_steps_per_epoch: null 
-gradient_accumulation_steps: 8 # Use to increase effective batch size -compile: True # torch.compile the model + loss, True increases speed + decreases memory - -# Logging -metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger - log_dir: ${output_dir}/logs -log_every_n_steps: 10 -log_peak_memory_stats: False - -# Environment -device: hpu -dtype: bf16 - -# Activations Memory -enable_activation_checkpointing: False # True reduces memory -enable_activation_offloading: False # True reduces memory - - -# Profiler (disabled) -profiler: - _component_: torchtune.training.setup_torch_profiler - enabled: False - - #Output directory of trace artifacts - output_dir: ${output_dir}/profiling_outputs - - #`torch.profiler.ProfilerActivity` types to trace - cpu: True - cuda: False - - #trace options passed to `torch.profiler.profile` - profile_memory: False - with_stack: False - record_shapes: True - with_flops: False - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 5 - warmup_steps: 3 - active_steps: 2 - num_cycles: 1 diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 1c2389a4a0..51ac1b9bb7 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -1012,13 +1012,10 @@ def recipe_main(cfg: DictConfig) -> None: "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) -<<<<<<< HEAD -======= - if torch.cuda.is_available(): + if cfg.device == "hpu": + init_process_group(backend="hccl") + else: init_process_group("cuda:nccl,cpu:gloo") - elif torch.hpu.is_available(): - init_process_group("hpu:hccl,cpu:gloo") ->>>>>>> 0b4f8cc1 (Add HPU as new device) if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x # speed up when benchmarking fused AdamW on CPU From 85eae81e4f880f380fd36f1b14ea2cbae72853e3 Mon Sep 17 00:00:00 2001 From: Vivek Date: Mon, 5 May 2025 12:40:49 +0300 Subject: [PATCH 3/6] delete hpu eval configs --- .../eval_configs_hpu/custom_eval_base.yaml | 37 ------------------ .../eval_configs_hpu/custom_eval_full_ft.yaml | 37 ------------------ .../eval_configs_hpu/custom_eval_lora_ft.yaml | 38 ------------------- 3 files changed, 112 deletions(-) delete mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml delete mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml delete mode 100644 recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml deleted file mode 100644 index 4e8e034aa4..0000000000 --- a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_base.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - checkpoint_files: [ - model-00001-of-00004.safetensors, - model-00002-of-00004.safetensors, - model-00003-of-00004.safetensors, - model-00004-of-00004.safetensors - ] - recipe_checkpoint: null - output_dir: /tmp/torchtune/llama3_1_8B/full_single_device - model_type: LLAMA3 - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: 512 - -# Environment -device: hpu -dtype: bf16 -seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed -compile: True -# EleutherAI specific eval args -tasks: ["hellaswag"] -limit: null -max_seq_length: 512 -batch_size: 32 -enable_kv_cache: True - -# Quantization specific args 
-quantizer: null diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml deleted file mode 100644 index 8e63db29eb..0000000000 --- a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_full_ft.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/torchtune/llama3_1_8B/full_single_device/base_model/ - checkpoint_files: [ - /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00001-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00002-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00003-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/full_single_device/epoch_0/ft-model-00004-of-00004.safetensors, - ] - recipe_checkpoint: null - output_dir: /tmp/torchtune/llama3_1_8B/full_single_device - model_type: LLAMA3 - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: 512 - -# Environment -device: hpu -dtype: bf16 -seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed -compile: True -# EleutherAI specific eval args -tasks: ["hellaswag"] -limit: null -max_seq_length: 512 -batch_size: 32 -enable_kv_cache: True - -# Quantization specific args -quantizer: null diff --git a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml b/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml deleted file mode 100644 index 05a1b9d8cd..0000000000 --- a/recipes/configs/llama3_1/eval_configs_hpu/custom_eval_lora_ft.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -Full finetuned -checkpointer: - _component_: 
torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - checkpoint_files: [ - /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00001-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00002-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00003-of-00004.safetensors, - /tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0/ft-model-00004-of-00004.safetensors, - ] - recipe_checkpoint: null - output_dir: output - model_type: LLAMA3 - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: 512 - -# Environment -device: hpu -dtype: bf16 -seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed -compile: True -# EleutherAI specific eval args -tasks: ["hellaswag"] -limit: null -max_seq_length: 512 -batch_size: 32 -enable_kv_cache: True - -# Quantization specific args -quantizer: null From 2bd999c0d9830bbcd261d35a33c15c5f4d0ffdd9 Mon Sep 17 00:00:00 2001 From: Vivek Date: Mon, 5 May 2025 14:26:51 +0300 Subject: [PATCH 4/6] fix lora recipe after rebase --- recipes/lora_finetune_distributed.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 51ac1b9bb7..51ddcd49aa 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -1012,10 +1012,6 @@ def recipe_main(cfg: DictConfig) -> None: "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - if cfg.device == "hpu": - init_process_group(backend="hccl") - else: - init_process_group("cuda:nccl,cpu:gloo") if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x # speed up when benchmarking fused AdamW on CPU From dca0febadd0b8b1fb672d1e766bbcb77edbf124d Mon Sep 17 00:00:00 2001 From: Vivek Date: Thu, 29 May 2025 06:57:20 +0300 Subject: [PATCH 5/6] Add HPU support to Profiler --- torchtune/training/_profiler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchtune/training/_profiler.py b/torchtune/training/_profiler.py index 47970ce7c7..74e4b79284 100644 --- a/torchtune/training/_profiler.py +++ b/torchtune/training/_profiler.py @@ -181,6 +181,7 @@ def setup_torch_profiler( cpu: bool = True, cuda: bool = True, xpu: bool = True, + hpu: bool = False, profile_memory: bool = DEFAULT_TRACE_OPTS["profile_memory"], with_stack: bool = DEFAULT_TRACE_OPTS["with_stack"], record_shapes: bool = DEFAULT_TRACE_OPTS["record_shapes"], @@ -249,6 +250,7 @@ def setup_torch_profiler( cpu (bool): Enable cpu profiling. Default is True. cuda (bool): Enable cuda profiling. Default is True. xpu (bool): Enable xpu profiling. Default is True. + hpu (bool): Enable hpu profiling. Default is False. profile_memory (bool): Profile memory usage. Default is False. with_stack (bool): Profile stack. Default is False. record_shapes (bool): Record shapes. Default is True. 
@@ -275,6 +277,8 @@ def setup_torch_profiler( activities.append(torch.profiler.ProfilerActivity.CUDA) if xpu: activities.append(torch.profiler.ProfilerActivity.XPU) + if hpu: + activities.append(torch.profiler.ProfilerActivity.HPU) if len(activities) == 0: _warn("No activities specified, defaulting to CPU + CUDA") activities = DEFAULT_PROFILER_ACTIVITIES @@ -372,6 +376,7 @@ def setup_torch_profiler( "cpu": cpu, "cuda": cuda, "xpu": xpu, + "hpu": hpu, "profile_memory": profile_memory, "with_stack": with_stack, "record_shapes": record_shapes, From d11bbca4f98d8e77248eaefbdda72d37799ec7ca Mon Sep 17 00:00:00 2001 From: Rohit kumar Singh Date: Wed, 4 Jun 2025 17:02:30 +0300 Subject: [PATCH 6/6] Add CP Support for HPU --- recipes/configs/llama3_1/8B_lora.yaml | 6 ++ recipes/lora_finetune_distributed.py | 96 ++++++++++++++++--- torchtune/data/__init__.py | 3 +- torchtune/data/_utils.py | 47 +++++++++ torchtune/models/llama3/_model_builders.py | 2 + torchtune/models/llama3/_tokenizer.py | 14 ++- .../models/llama3_1/_component_builders.py | 11 ++- torchtune/models/llama3_1/_model_builders.py | 2 + torchtune/modules/attention.py | 10 +- torchtune/modules/attention_utils.py | 28 ++++++ torchtune/training/__init__.py | 4 + torchtune/training/_distributed.py | 85 ++++++++++++++-- 12 files changed, 281 insertions(+), 27 deletions(-) diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index b1b53690e1..6178fc1feb 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -23,6 +23,7 @@ output_dir: /tmp/torchtune/llama3_1_8B/lora # /tmp may be deleted by your system tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + pad_to_max_seq_len: True # True pads to max_seq_len, enable (ideally) when dataset.packed=False max_seq_len: null # Model Arguments @@ -34,6 +35,11 @@ model: lora_rank: 8 # higher increases accuracy and memory 
lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 + attn_func: 'hpu_scaled_dot_product_attention' + +# Parallelism +context_parallel_dim: 2 +context_parallel_rotate_method: 'alltoall' # 'alltoall' or 'all-gather' checkpointer: _component_: torchtune.training.FullModelHFCheckpointer diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 51ddcd49aa..914a0addcb 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -142,16 +142,49 @@ def __init__(self, cfg: DictConfig) -> None: ) init_process_group(self.distributed_backend) + # Initialize distributed variables self.world_size, self.rank = utils.get_world_size_and_rank() - self._is_rank_zero = self.rank == 0 + data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer + data_replicate = cfg.get("data_parallel_replicate_dim", 1) + self.cp_degree = cfg.get("context_parallel_dim", 1) + self.context_parallel_rotate_method = cfg.get("context_parallel_rotate_method", "none") + + # Set up n-d device mesh + self.parallel_dims = training.ParallelDims( + dp_replicate=data_replicate, + dp_shard=data_shard, + tp=1, + cp=self.cp_degree, + world_size=self.world_size, + ) + self.world_mesh = self.parallel_dims.build_mesh(device_type=cfg.device) + + if self.parallel_dims.dp_enabled: + dp_mesh = self.world_mesh["dp"] + self.dp_degree, self.dp_rank = ( + dp_mesh.size(), + dp_mesh.get_local_rank(), + ) + else: + self.dp_degree, self.dp_rank = 1, 0 + + self.train_context = training.get_train_context( + enable_loss_parallel=False, + enable_compiled_autograd=False, + ) # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) self._logger = utils.get_logger(cfg.log_level) + utils.log_rank_zero( + self._logger, + f" WORLD DEVICE MESH: {self.world_mesh}, CP DEVICE MESH: {self.world_mesh['cp']}\n", + ) + if 
self._log_peak_memory_stats and self._device.type not in {"cuda", "xpu"}: self._logger.info( "log_peak_memory_stats was set to True, however, training does not use cuda or xpu." @@ -519,11 +552,17 @@ def _setup_model( names_to_match=custom_sharded_layers, ) ] + if self.parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + training.shard_model( model=model, shard_conditions=fsdp_shard_conditions, cpu_offload=fsdp_cpu_offload, reshard_after_forward=reshard_after_forward, + dp_mesh=self.world_mesh[dp_mesh_dim_names], ) if lora_weights_state_dict: @@ -641,10 +680,10 @@ def _setup_data( for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) - packed = getattr(ds, "packed", False) + self.packed_dataset = getattr(ds, "packed", False) else: ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) + self.packed_dataset = cfg_dataset.get("packed", False) # Instantiate collate_fn if "left_pad_sequence" in collate_fn: @@ -667,7 +706,7 @@ def _setup_data( padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) - if not packed + if not self.packed_dataset else padded_collate_packed ), # dropping last avoids shape issues with compile + flex attention @@ -930,17 +969,50 @@ def train(self) -> None: def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") + if self.cp_degree>1 and self._tokenizer.pad_to_max_seq_len and not self.packed_dataset: + # For cp we need to pass the buffer and the buffer dim (cp_seq_dims) to apply cp + cp_buffers = [] + for k,v in batch.items(): + if isinstance(v, torch.Tensor): + cp_buffers.append(v) + + cp_buffers.append(labels) + + optional_context_parallel_ctx = ( + training.create_context_parallel_ctx( + cp_mesh=self.world_mesh["cp"], + cp_buffers=cp_buffers, + cp_seq_dims=[1]*len(cp_buffers), + 
cp_no_restore_buffers=set(cp_buffers), + cp_rotate_method=self.context_parallel_rotate_method, + ) + if self.parallel_dims.cp_enabled + else None + ) + + with optional_context_parallel_ctx: + with self.activations_handling_ctx: + outputs = self._model(**batch) - with self.activations_handling_ctx: - outputs = self._model(**batch) + if self.linear_loss: + weight = self._model.linear_projection_weight + loss = self._loss_fn(weight, outputs, labels) + else: + labels = labels.reshape(-1) + outputs = outputs.reshape(-1, outputs.size(-1)) + loss = self._loss_fn(outputs, labels) - if self.linear_loss: - weight = self._model.linear_projection_weight - loss = self._loss_fn(weight, outputs, labels) else: - labels = labels.reshape(-1) - outputs = outputs.reshape(-1, outputs.size(-1)) - loss = self._loss_fn(outputs, labels) + with self.activations_handling_ctx: + outputs = self._model(**batch) + + if self.linear_loss: + weight = self._model.linear_projection_weight + loss = self._loss_fn(weight, outputs, labels) + else: + labels = labels.reshape(-1) + outputs = outputs.reshape(-1, outputs.size(-1)) + loss = self._loss_fn(outputs, labels) # free logits otherwise it peaks backward memory del outputs diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index a75e16780a..0ef8aca2f3 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -32,7 +32,7 @@ QuestionAnswerTemplate, SummarizeTemplate, ) -from torchtune.data._utils import format_content_with_images, load_image, truncate +from torchtune.data._utils import format_content_with_images, load_image, truncate, pad_tokens __all__ = [ "CROSS_ENTROPY_IGNORE_IDX", @@ -60,4 +60,5 @@ "padded_collate_tiled_images_and_mask", "padded_collate_packed", "load_image", + "pad_tokens" ] diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 6f086a41fa..0b584924b0 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -62,6 +62,53 @@ def truncate( return tokens_truncated +def 
def pad_tokens(
    tokens: List[Any],
    mask: List[Any],
    max_seq_len: int,
    pad_id: int,
    eos_id: Optional[Any] = None,
    pad_type: str = "right",
) -> tuple[List[Any], List[Any]]:
    """
    Pad a list of tokens and its mask to ``max_seq_len``.

    If ``eos_id`` is provided, the last token (which may be a padding token
    after padding) is replaced with ``eos_id``.

    Args:
        tokens (List[Any]): list of tokens to pad
        mask (List[Any]): list of mask values to pad alongside ``tokens``
        max_seq_len (int): target length; no padding is added when
            ``len(tokens) >= max_seq_len`` (no truncation is performed either)
        pad_id (int): token used for padding
        eos_id (Optional[Any]): token to replace the last token with. If None,
            the last token is left unchanged. Default is None.
        pad_type (str): side to pad on, either "left" or "right".
            Default is "right".

    Returns:
        tuple[List[Any], List[Any]]: padded tokens and padded mask
        (``False`` marks padding positions).

    Raises:
        ValueError: if padding is needed and ``pad_type`` is not
            "left" or "right"
    """
    padding_length = max_seq_len - len(tokens)
    if padding_length > 0:
        if pad_type == "right":
            tokens = tokens + [pad_id] * padding_length
            mask = mask + [False] * padding_length
        elif pad_type == "left":
            tokens = [pad_id] * padding_length + tokens
            mask = [False] * padding_length + mask
        else:
            # Bug fix: the original message blamed "truncation_type",
            # but the offending argument here is pad_type.
            raise ValueError(
                f"pad_type must be 'left' or 'right', got {pad_type}"
            )

    # Force the sequence to end with eos_id if requested.
    if eos_id is not None and tokens and tokens[-1] != eos_id:
        tokens[-1] = eos_id

    return tokens, mask
@@ -106,6 +107,7 @@ def llama3_tokenizer( max_seq_len=max_seq_len, prompt_template=template, truncation_type=truncation_type, + pad_to_max_seq_len=pad_to_max_seq_len, ) diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index 349287f95b..e6a65a20c1 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -7,7 +7,7 @@ import re from typing import Any, Dict, List, Mapping, Optional, Tuple -from torchtune.data import Message, PromptTemplate, truncate +from torchtune.data import Message, PromptTemplate, truncate, pad_tokens from torchtune.modules.transforms import Transform from torchtune.modules.transforms.tokenizers import ( ModelTokenizer, @@ -82,6 +82,7 @@ def __init__( max_seq_len: Optional[int] = None, prompt_template: Optional[PromptTemplate] = None, truncation_type: str = "right", + pad_to_max_seq_len: bool = False, ): self.special_tokens = ( special_tokens if special_tokens is not None else LLAMA3_SPECIAL_TOKENS @@ -118,7 +119,7 @@ def __init__( special_tokens=self.special_tokens, ) self.max_seq_len = max_seq_len - + self.pad_to_max_seq_len = pad_to_max_seq_len self.prompt_template = prompt_template # Regex for removing special tokens from the decoded string @@ -341,6 +342,15 @@ def tokenize_messages( truncation_type=self.truncation_type, ) + # Add padding if pad_to_max_seq_len=True and max_seq_len is set + if self.pad_to_max_seq_len: + tokens, mask = pad_tokens( + tokens=tokens, + mask=mask, + max_seq_len=self.max_seq_len, + pad_id=self.pad_id, + ) + return tokens, mask def __call__( diff --git a/torchtune/models/llama3_1/_component_builders.py b/torchtune/models/llama3_1/_component_builders.py index 3fc2431c12..c58464fc12 100644 --- a/torchtune/models/llama3_1/_component_builders.py +++ b/torchtune/models/llama3_1/_component_builders.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source 
tree. -from typing import List, Optional +from typing import List, Optional, Callable from torch import nn +from torchtune.modules.attention_utils import _sdpa_or_flex_attention, scaled_dot_product_attention from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE from torchtune.modules import ( @@ -139,6 +140,7 @@ def lora_llama3_1( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, apply_lora_to_output: bool = False, + attn_func: str = None, *, # llama3.1 args vocab_size: int, @@ -206,7 +208,9 @@ def lora_llama3_1( hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) head_dim = embed_dim // num_heads rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base, scale_factor=scale_factor) - + + attn_func = _sdpa_or_flex_attention if attn_func is None else scaled_dot_product_attention + layers = nn.ModuleList() for _ in range(num_layers): self_attn = lora_llama3_attention( @@ -218,6 +222,7 @@ def lora_llama3_1( num_kv_heads=num_kv_heads, max_seq_len=max_seq_len, attn_dropout=attn_dropout, + attn_func=attn_func, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, @@ -287,6 +292,7 @@ def lora_llama3_attention( max_seq_len: int, is_causal: bool = True, attn_dropout: float = 0.0, + attn_func: Optional[Callable] = _sdpa_or_flex_attention, # LoRA args lora_rank: int, lora_alpha: float, @@ -420,6 +426,7 @@ def lora_llama3_attention( max_seq_len=max_seq_len, is_causal=is_causal, attn_dropout=attn_dropout, + attn_func=attn_func, ) return self_attn diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py index f48ce580f5..8b6bf9c11b 100644 --- a/torchtune/models/llama3_1/_model_builders.py +++ b/torchtune/models/llama3_1/_model_builders.py @@ -90,6 +90,7 @@ def lora_llama3_1_8b( lora_dropout: float = 0.0, use_dora: bool = False, 
quantize_base: bool = False, + attn_func: str = None, ) -> TransformerDecoder: """ Builder for creating a Llama3.1 8B model with LoRA enabled. @@ -135,6 +136,7 @@ def lora_llama3_1_8b( lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, + attn_func=attn_func, ) diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index ff6faccb5d..90d90da8d5 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional +from typing import Optional, Callable import torch from torch import nn @@ -98,6 +98,7 @@ def __init__( max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0, + attn_func: Callable = _sdpa_or_flex_attention, ) -> None: super().__init__() if num_heads % num_kv_heads != 0: @@ -137,8 +138,11 @@ def __init__( self.k_norm = k_norm self.pos_embeddings = pos_embeddings - # Use flex attention if supported and we are sample packing - self._attention_call = _sdpa_or_flex_attention() + # Use flex/sdpa attention if supported and we are sample packing else use custom attention + if attn_func is _sdpa_or_flex_attention: + self._attention_call = _sdpa_or_flex_attention() + else: + self._attention_call = attn_func # this flag indicates whether to update the kv-cache during forward # passes. 
# Reference implementation adapted from
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def scaled_dot_product_attention(query, key, value, mask=None, dropout_p=0.0,
                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """
    Pure-PyTorch scaled dot product attention, numerically matching
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        query, key, value: attention inputs of shape ``(..., seq, head_dim)``
        mask: optional attention mask; bool masks mark allowed positions,
            float masks are added to the attention scores
        dropout_p (float): dropout probability applied to attention weights
        is_causal (bool): apply a causal (lower-triangular) mask;
            mutually exclusive with ``mask``
        scale: optional score scale; defaults to ``1/sqrt(head_dim)``
        enable_gqa (bool): repeat key/value heads to match query heads
            (grouped-query attention)

    Returns:
        torch.Tensor: attention output, same shape as ``query``
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1.0 / (query.size(-1) ** 0.5) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        # Bug fix: Tensor.to is not in-place; the original discarded the
        # converted tensor (harmless here since attn_bias is already created
        # with query.dtype, but the intent is an assignment).
        attn_bias = attn_bias.to(query.dtype)

    if mask is not None:
        if mask.dtype == torch.bool:
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
        else:
            attn_bias = mask + attn_bias

    if enable_gqa:
        # Expand KV heads so each query head has a matching key/value head.
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index b2c327c617..2ae40d180d 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -15,6 +15,8 @@ get_full_optimizer_state_dict, get_shard_conditions, get_world_size_and_rank, + get_train_context, + create_context_parallel_ctx, init_distributed, is_distributed, load_from_full_model_state_dict, @@ -142,4 +144,6 @@ "get_distributed_backend", "disable_dropout", "DATALOADER_KEY", + "get_train_context", + "create_context_parallel_ctx", ] diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 6ffdf69a46..5099f7a173 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -7,6 +7,8 @@ import logging import os +import contextlib +from collections.abc import Generator from dataclasses import dataclass from functools import cached_property from itertools import chain @@ -54,36 +56,38 @@ class ParallelDims: dp_replicate: int dp_shard: int tp: int + cp: int world_size: int def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, tp = ( + dp_replicate, dp_shard, tp, cp = ( self.dp_replicate, self.dp_shard, self.tp, + self.cp, ) - for d in (dp_replicate, tp): + for d in (dp_replicate, tp, cp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." 
if dp_shard < 0: - self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp) + self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp) assert dp_shard >= 1 - assert dp_replicate * dp_shard * tp == self.world_size, ( + assert dp_replicate * dp_shard * tp * cp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " - f"tp({tp}) != WORLD_SIZE({self.world_size})" + f"tp({tp}) * tp({cp}) != WORLD_SIZE({self.world_size})" ) def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.dp_replicate, self.dp_shard, self.tp], - ["dp_replicate", "dp_shard", "tp"], + [self.dp_replicate, self.dp_shard, self.tp, self.cp], + ["dp_replicate", "dp_shard", "tp", "cp"], ): if d > 1: dims.append(d) @@ -96,14 +100,30 @@ def build_mesh(self, device_type): # initialized: # Mesh for data loading (no communication on this mesh) dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] if self.dp_replicate_enabled: dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") if self.dp_shard_enabled: dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") if dp_mesh_dim_names != []: mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh @@ -123,6 +143,10 @@ def dp_shard_enabled(self): def tp_enabled(self): return self.tp > 1 + @property + def cp_enabled(self): + return self.cp > 1 + @cached_property def non_data_parallel_size(self): # update below as more parallelism options are 
implemented @@ -658,6 +682,53 @@ def shard_model( fully_shard(model, **fsdp_kwargs) +def create_context_parallel_ctx( + cp_mesh: DeviceMesh, + cp_buffers: list[torch.Tensor], + cp_seq_dims: list[int], + cp_no_restore_buffers: set[torch.Tensor], + cp_rotate_method: str, +): + try: + from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method + except ImportError: + print( + f"PyTorch version {torch.__version__} does not include the experimental " + "Context Parallel API. Please update to a newer version." + ) + + set_rotate_method(cp_rotate_method) + return context_parallel( + cp_mesh, + buffers=cp_buffers, + buffer_seq_dims=cp_seq_dims, + no_restore_buffers=cp_no_restore_buffers, + ) + + +def get_train_context( + enable_loss_parallel: bool, enable_compiled_autograd: bool +) -> Generator[None, None, None]: + @contextlib.contextmanager + def context(cp_context: Generator[None, None, None] | None = None, activations_handling_ctx: Generator[None, None, None] | None = None): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) + + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + if cp_context is not None: + stack.enter_context(cp_context) + + yield + + return context + + def prepare_mha_for_tp( model: nn.Module, tp_mesh: DeviceMesh,