Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ output_dir: /tmp/torchtune/llama3_1_8B/lora # /tmp may be deleted by your system
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
pad_to_max_seq_len: True # If True, pads each sequence to max_seq_len; ideally enable only when dataset.packed=False
max_seq_len: null

# Model Arguments
Expand All @@ -34,6 +35,11 @@ model:
lora_rank: 8 # higher increases accuracy and memory
lora_alpha: 16 # usually alpha=2*rank
lora_dropout: 0.0
attn_func: 'hpu_scaled_dot_product_attention'

# Parallelism
context_parallel_dim: 2
context_parallel_rotate_method: 'alltoall' # 'alltoall' or 'allgather'

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
Expand Down
3 changes: 3 additions & 0 deletions recipes/eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ def __init__(
self._batch_size = batch_size
self._dtype = dtype
self._enable_kv_cache = enable_kv_cache
# Set device explicitly here since HPU is not included in
# `device_list` in `HFLM` class
self._device = torch.device(device)

@property
def model(self):
Expand Down
133 changes: 110 additions & 23 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,49 @@ def __init__(self, cfg: DictConfig) -> None:
)
init_process_group(self.distributed_backend)

# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()

self._is_rank_zero = self.rank == 0
data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer
data_replicate = cfg.get("data_parallel_replicate_dim", 1)
self.cp_degree = cfg.get("context_parallel_dim", 1)
self.context_parallel_rotate_method = cfg.get("context_parallel_rotate_method", "none")

# Set up n-d device mesh
self.parallel_dims = training.ParallelDims(
dp_replicate=data_replicate,
dp_shard=data_shard,
tp=1,
cp=self.cp_degree,
world_size=self.world_size,
)
self.world_mesh = self.parallel_dims.build_mesh(device_type=cfg.device)

if self.parallel_dims.dp_enabled:
dp_mesh = self.world_mesh["dp"]
self.dp_degree, self.dp_rank = (
dp_mesh.size(),
dp_mesh.get_local_rank(),
)
else:
self.dp_degree, self.dp_rank = 1, 0


self.train_context = training.get_train_context(
enable_loss_parallel=False,
enable_compiled_autograd=False,
)
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
self._logger = utils.get_logger(cfg.log_level)

utils.log_rank_zero(
self._logger,
f" WORLD DEVICE MESH: {self.world_mesh}, CP DEVICE MESH: {self.world_mesh['cp']}\n",
)

if self._log_peak_memory_stats and self._device.type not in {"cuda", "xpu"}:
self._logger.info(
"log_peak_memory_stats was set to True, however, training does not use cuda or xpu."
Expand Down Expand Up @@ -498,18 +531,38 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

if torch.hpu.is_available():
# Initialize LoRA params and RoPE buffers (Before FSDP sharding)
with training.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if fsdp_cpu_offload else self._device
for m in model.modules():
if (isinstance(m, AdapterModule)) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.to_empty(device=lora_device)
m.initialize_parameters()

if hasattr(m, "rope_init"):
m.rope_init()

# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]
if self.parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
else:
dp_mesh_dim_names = ("dp_shard_cp",)

training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
dp_mesh=self.world_mesh[dp_mesh_dim_names],
)

if lora_weights_state_dict:
Expand All @@ -522,18 +575,19 @@ def _setup_model(
else:
lora_missing, lora_unexpected = None, None

# Initialize LoRA params and RoPE buffers
with training.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if fsdp_cpu_offload else self._device
for m in model.modules():
if (isinstance(m, AdapterModule)) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.to_empty(device=lora_device)
m.initialize_parameters()
if torch.cuda.is_available():
# Initialize LoRA params and RoPE buffers
with training.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if fsdp_cpu_offload else self._device
for m in model.modules():
if (isinstance(m, AdapterModule)) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.to_empty(device=lora_device)
m.initialize_parameters()

if hasattr(m, "rope_init"):
m.rope_init()
if hasattr(m, "rope_init"):
m.rope_init()

base_missing, base_unexpected = training.load_from_full_model_state_dict(
model,
Expand Down Expand Up @@ -626,10 +680,10 @@ def _setup_data(
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = getattr(ds, "packed", False)
self.packed_dataset = getattr(ds, "packed", False)
else:
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)
self.packed_dataset = cfg_dataset.get("packed", False)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
Expand All @@ -652,7 +706,7 @@ def _setup_data(
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
if not self.packed_dataset
else padded_collate_packed
),
# dropping last avoids shape issues with compile + flex attention
Expand Down Expand Up @@ -915,17 +969,50 @@ def train(self) -> None:
def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
if self.cp_degree>1 and self._tokenizer.pad_to_max_seq_len and not self.packed_dataset:
# For cp we need to pass the buffer and the buffer dim (cp_seq_dims) to apply cp
cp_buffers = []
for k,v in batch.items():
if isinstance(v, torch.Tensor):
cp_buffers.append(v)

cp_buffers.append(labels)

optional_context_parallel_ctx = (
training.create_context_parallel_ctx(
cp_mesh=self.world_mesh["cp"],
cp_buffers=cp_buffers,
cp_seq_dims=[1]*len(cp_buffers),
cp_no_restore_buffers=set(cp_buffers),
cp_rotate_method=self.context_parallel_rotate_method,
)
if self.parallel_dims.cp_enabled
else None
)

with optional_context_parallel_ctx:
with self.activations_handling_ctx:
outputs = self._model(**batch)

with self.activations_handling_ctx:
outputs = self._model(**batch)
if self.linear_loss:
weight = self._model.linear_projection_weight
loss = self._loss_fn(weight, outputs, labels)
else:
labels = labels.reshape(-1)
outputs = outputs.reshape(-1, outputs.size(-1))
loss = self._loss_fn(outputs, labels)

if self.linear_loss:
weight = self._model.linear_projection_weight
loss = self._loss_fn(weight, outputs, labels)
else:
labels = labels.reshape(-1)
outputs = outputs.reshape(-1, outputs.size(-1))
loss = self._loss_fn(outputs, labels)
with self.activations_handling_ctx:
outputs = self._model(**batch)

if self.linear_loss:
weight = self._model.linear_projection_weight
loss = self._loss_fn(weight, outputs, labels)
else:
labels = labels.reshape(-1)
outputs = outputs.reshape(-1, outputs.size(-1))
loss = self._loss_fn(outputs, labels)

# free logits otherwise it peaks backward memory
del outputs
Expand Down
3 changes: 2 additions & 1 deletion torchtune/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
QuestionAnswerTemplate,
SummarizeTemplate,
)
from torchtune.data._utils import format_content_with_images, load_image, truncate
from torchtune.data._utils import format_content_with_images, load_image, truncate, pad_tokens

__all__ = [
"CROSS_ENTROPY_IGNORE_IDX",
Expand Down Expand Up @@ -60,4 +60,5 @@
"padded_collate_tiled_images_and_mask",
"padded_collate_packed",
"load_image",
"pad_tokens"
]
47 changes: 47 additions & 0 deletions torchtune/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,53 @@ def truncate(

return tokens_truncated

def pad_tokens(
    tokens: List[Any],
    mask: List[Any],
    max_seq_len: int,
    pad_id: int,
    eos_id: Optional[Any] = None,
    pad_type: str = "right",
) -> Tuple[List[Any], List[Any]]:
    """
    Pad a list of tokens and its parallel mask to a maximum sequence length.
    If eos_id is provided and the last token differs, the last token is
    replaced with eos_id.

    Args:
        tokens (List[Any]): list of tokens to pad
        mask (List[Any]): list of mask entries to pad (padding positions get ``False``)
        max_seq_len (int): target length; no padding is added if ``len(tokens) >= max_seq_len``
        pad_id (int): token used for padding
        eos_id (Optional[Any]): token to replace the last token with. If None, the
            last token will not be replaced. Default is None.
        pad_type (str): type of padding to apply, either "left" or "right".
            Default is "right".

    Returns:
        Tuple[List[Any], List[Any]]: padded tokens and padded mask (falsy entries mark padding)

    Raises:
        ValueError: if pad_type is not "left" or "right"
    """
    # Validate up front so a bad pad_type is rejected even when no padding
    # is needed (previously it was silently accepted in that case).
    if pad_type not in ("left", "right"):
        raise ValueError(f"pad_type must be 'left' or 'right', got {pad_type}")

    padding_length = max_seq_len - len(tokens)
    if padding_length > 0:
        if pad_type == "right":
            tokens = tokens + [pad_id] * padding_length
            mask = mask + [False] * padding_length
        else:  # "left"
            tokens = [pad_id] * padding_length + tokens
            mask = [False] * padding_length + mask

    # Replace the last token with eos_id if necessary
    if eos_id is not None and tokens and tokens[-1] != eos_id:
        tokens[-1] = eos_id

    return tokens, mask

def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
"""
Expand Down
2 changes: 2 additions & 0 deletions torchtune/models/llama3/_model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def llama3_tokenizer(
max_seq_len: Optional[int] = None,
prompt_template: Optional[_TemplateType] = None,
truncation_type: str = "right",
pad_to_max_seq_len: bool = False,
) -> Llama3Tokenizer:
"""
Tokenizer for Llama3.
Expand Down Expand Up @@ -106,6 +107,7 @@ def llama3_tokenizer(
max_seq_len=max_seq_len,
prompt_template=template,
truncation_type=truncation_type,
pad_to_max_seq_len=pad_to_max_seq_len,
)


Expand Down
14 changes: 12 additions & 2 deletions torchtune/models/llama3/_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate, truncate
from torchtune.data import Message, PromptTemplate, truncate, pad_tokens
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(
max_seq_len: Optional[int] = None,
prompt_template: Optional[PromptTemplate] = None,
truncation_type: str = "right",
pad_to_max_seq_len: bool = False,
):
self.special_tokens = (
special_tokens if special_tokens is not None else LLAMA3_SPECIAL_TOKENS
Expand Down Expand Up @@ -118,7 +119,7 @@ def __init__(
special_tokens=self.special_tokens,
)
self.max_seq_len = max_seq_len

self.pad_to_max_seq_len = pad_to_max_seq_len
self.prompt_template = prompt_template

# Regex for removing special tokens from the decoded string
Expand Down Expand Up @@ -341,6 +342,15 @@ def tokenize_messages(
truncation_type=self.truncation_type,
)

# Add padding if pad_to_max_seq_len=True and max_seq_len is set
if self.pad_to_max_seq_len:
tokens, mask = pad_tokens(
tokens=tokens,
mask=mask,
max_seq_len=self.max_seq_len,
pad_id=self.pad_id,
)

return tokens, mask

def __call__(
Expand Down
Loading