forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheagle_utils.py
More file actions
717 lines (609 loc) · 26.7 KB
/
eagle_utils.py
File metadata and controls
717 lines (609 loc) · 26.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import FrameType
from typing import Any
import numpy as np
import torch
import transformers
from datasets import load_dataset
from packaging.version import Version
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
from transformers import AutoProcessor, Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother
import modelopt
from modelopt.torch.speculative.utils import get_ttt_msk_func
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import is_master
try:
import wandb
except ImportError:
wandb = None
# Label id that Hugging Face's LabelSmoother ignores; used for label positions that
# must not contribute to the loss.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Jinja fragment that drops everything up to the last "</think>" from a message's content.
# The preprocess functions remove this fragment from the tokenizer's chat template so the
# reasoning text is kept in the rendered training conversations.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}"
    "{% set content = content.split('</think>')[-1] %}"
    "{% endif %}"
)
# Lower-cased ShareGPT-style role names mapped to the roles chat templates expect.
_ROLE_ALIASES = {"human": "user", "gpt": "assistant"}


def _keep_think_content(tokenizer):
    """Remove REMOVE_THINK_CHAT_TEMPLATE from ``tokenizer.chat_template`` in place.

    With the fragment removed, text up to ``</think>`` in a message is no longer dropped
    when the template is rendered. Non-string templates (e.g. ``None``) are left untouched
    instead of raising AttributeError.
    """
    template = tokenizer.chat_template
    if isinstance(template, str):
        tokenizer.chat_template = template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")


def _get_role_content(item):
    """Return ``(role, content)`` for a message in role/content or from/value format.

    Raises:
        ValueError: If the message uses neither format.
    """
    if "role" in item and "content" in item:
        return item["role"], item["content"]
    if "from" in item and "value" in item:
        return item["from"], item["value"]
    raise ValueError(f"Unknown conversation format: {item}")


def _normalize_role(role):
    """Lower-case ``role`` and map the aliases "human"/"gpt" to "user"/"assistant"."""
    role = role.lower()
    return _ROLE_ALIASES.get(role, role)


def _build_messages(source):
    """Convert raw conversation turns into ``{"role", "content"}`` chat messages."""
    messages = []
    for sentence in source:
        role, content = _get_role_content(sentence)
        messages.append({"role": _normalize_role(role), "content": content})
    return messages


def _render_conversation(tokenizer, source):
    """Render raw conversation turns to one string using the tokenizer's chat template."""
    return tokenizer.apply_chat_template(
        _build_messages(source),
        tokenize=False,
        add_generation_prompt=False,
    )


def _append_text_fields(new_examples, input_ids, attention_mask):
    """Append one sample's token-level fields to ``new_examples``.

    Every token is included in the loss (``loss_mask`` is all ones). ``labels`` is
    ``input_ids`` shifted left by one, with IGNORE_TOKEN_ID in the final position.
    """
    new_examples["input_ids"].append(input_ids)
    new_examples["attention_mask"].append(attention_mask)
    new_examples["loss_mask"].append(torch.ones_like(input_ids))
    new_examples["labels"].append(
        torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
    )


def preprocess(examples, tokenizer, **kwargs):
    """Tokenize text-only conversations for EAGLE training.

    Args:
        examples: Sequence of records, each with a ``"conversations"`` list of messages in
            either role/content or from/value format.
        tokenizer: Tokenizer with ``chat_template``, ``apply_chat_template`` and
            ``__call__``. Its chat template is modified in place so ``</think>`` reasoning
            content is kept.
        **kwargs: Ignored. Accepts ``processor``/``img_dir`` so callers can use this and
            ``preprocess_vlm`` interchangeably.

    Returns:
        dict: ``input_ids``, ``attention_mask``, ``loss_mask`` and ``labels``, each a list
        holding one 1-D tensor per example.

    Raises:
        ValueError: If a message is in neither supported format.
    """
    _keep_think_content(tokenizer)
    new_examples = {
        "input_ids": [],
        "attention_mask": [],
        "loss_mask": [],
        "labels": [],
    }
    for example in examples:
        conversation = _render_conversation(tokenizer, example["conversations"])
        output = tokenizer(
            conversation,
            return_tensors="pt",
            add_special_tokens=False,
            truncation=True,
        )
        _append_text_fields(new_examples, output.input_ids[0], output.attention_mask[0])
    return new_examples


def preprocess_vlm(examples, tokenizer, processor, img_dir):
    """Tokenize image + text conversations for EAGLE training of a VLM.

    Args:
        examples: Sequence of records, each with a ``"conversations"`` list (role/content or
            from/value format) and an ``"image"`` path relative to ``img_dir``.
        tokenizer: Tokenizer whose chat template renders the conversation. The template is
            modified in place so ``</think>`` reasoning content is kept.
        processor: Multimodal processor called with ``images=``, ``text=`` and
            ``return_tensors="pt"``.
        img_dir: Directory that the ``"image"`` paths are relative to.

    Returns:
        dict: The same token-level fields as ``preprocess``, plus ``pixel_values`` (the
        processor's output per example) and ``image_flags`` (an int64 tensor of ones, one
        per leading row of that example's ``pixel_values``).

    Raises:
        ValueError: If a message is in neither supported format.
    """
    _keep_think_content(tokenizer)
    new_examples = {
        "input_ids": [],
        "attention_mask": [],
        "loss_mask": [],
        "labels": [],
        "pixel_values": [],
        "image_flags": [],
    }
    for example in examples:
        conversation = _render_conversation(tokenizer, example["conversations"])
        img_filename = os.path.join(img_dir, example["image"])
        # Close the image file once the processor has converted it to tensors.
        with Image.open(img_filename) as img:
            output = processor(images=img, text=conversation, return_tensors="pt")
        # TODO: add labels and answer-only loss masking?
        _append_text_fields(new_examples, output.input_ids[0], output.attention_mask[0])
        new_examples["pixel_values"].append(output.pixel_values)
        new_examples["image_flags"].append(
            torch.ones((output.pixel_values.shape[0],), dtype=torch.int64)
        )
    return new_examples
class SupervisedDataset(Dataset):
    """Eagerly preprocessed dataset for supervised fine-tuning.

    All raw examples go through the preprocessing function once, when the dataset is
    built. Items are then served by indexing into the resulting per-field lists.

    Args:
        raw_data (list): Raw conversation records to preprocess.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for preprocessing.
        vlm_processor: Optional multimodal processor. When given, ``preprocess_vlm`` is used
            instead of ``preprocess``.
        img_dir: Directory that image paths are relative to (VLM data only).
    """

    def __init__(
        self,
        raw_data,
        tokenizer: transformers.PreTrainedTokenizer,
        vlm_processor=None,
        img_dir=None,
    ):
        super().__init__()
        print_rank_0("Formatting inputs...")
        if vlm_processor is None:
            self.preprocess_fn = preprocess
        else:
            self.preprocess_fn = preprocess_vlm
        self.data_dict = self.preprocess_fn(
            raw_data, tokenizer, processor=vlm_processor, img_dir=img_dir
        )

    def __len__(self):
        return len(self.data_dict["input_ids"])

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        return {field: values[i] for field, values in self.data_dict.items()}
class LazySupervisedDataset(Dataset):
    """On-demand preprocessed dataset for supervised fine-tuning.

    Each raw example is preprocessed the first time it is requested. The result is then
    memoized in ``cached_data_dict``, keyed by index, and reused on later requests. This
    defers work until it is needed, at the cost of slower first access.

    Args:
        raw_data (list): Raw conversation records.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for preprocessing.
        vlm_processor: Optional multimodal processor. When given, ``preprocess_vlm`` is used
            instead of ``preprocess``.
        img_dir: Directory that image paths are relative to (VLM data only).
    """

    def __init__(
        self,
        raw_data,
        tokenizer: transformers.PreTrainedTokenizer,
        vlm_processor=None,
        img_dir=None,
    ):
        super().__init__()
        print_rank_0("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.raw_data = raw_data
        self.cached_data_dict = {}
        self.vlm_processor = vlm_processor
        self.img_dir = img_dir
        self.preprocess_fn = preprocess if vlm_processor is None else preprocess_vlm

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        try:
            return self.cached_data_dict[i]
        except KeyError:
            pass
        # Preprocess a one-element batch, then unwrap each field's single entry.
        batch = self.preprocess_fn(
            [self.raw_data[i]],
            self.tokenizer,
            processor=self.vlm_processor,
            img_dir=self.img_dir,
        )
        sample = {field: values[0] for field, values in batch.items()}
        self.cached_data_dict[i] = sample
        return sample
class OfflineSupervisedDataset(Dataset):
    """Dataset of pre-computed base-model hidden states for offline EAGLE training.

    Every entry pairs a raw conversation record with the path of a ``.pt`` file. That file
    holds ``input_ids``, ``hidden_states`` and ``aux_hidden_states`` for the conversation.
    The file is read from disk on every access. All three tensors are truncated to
    ``tokenizer.model_max_length`` along their first dimension.

    Args:
        data_entries (list): ``(raw_data_example, file_path)`` tuples.
        tokenizer (transformers.PreTrainedTokenizer): Supplies ``model_max_length``.
        vlm_processor: Only used to choose ``preprocess_fn``, which this class never calls.
        img_dir: Stored for interface parity with the other datasets.
    """

    def __init__(
        self,
        data_entries,
        tokenizer: transformers.PreTrainedTokenizer,
        vlm_processor=None,
        img_dir=None,
    ):
        super().__init__()
        print_rank_0("Formatting inputs...Skip in offline mode")
        self.tokenizer = tokenizer
        self.data_entries = data_entries
        self.vlm_processor = vlm_processor
        self.img_dir = img_dir
        self.preprocess_fn = preprocess if vlm_processor is None else preprocess_vlm
        # Never populated: hidden states are too large to keep in memory, so each
        # __getitem__ call re-reads its .pt file instead.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.data_entries)

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        # Only the file path is needed; the raw conversation record is not used here.
        _, pt_path = self.data_entries[i]
        limit = self.tokenizer.model_max_length
        # NOTE(review): torch.load deserializes the file; only load .pt files from trusted
        # sources (how much it may execute depends on torch's weights_only default).
        saved = torch.load(pt_path)
        token_ids = saved["input_ids"][:limit]
        return {
            "input_ids": token_ids,
            "attention_mask": torch.ones_like(token_ids),
            "loss_mask": torch.ones_like(token_ids),
            # No next-token labels are provided in offline mode.
            "labels": torch.full_like(token_ids, IGNORE_TOKEN_ID),
            "kwargs": {
                "base_model_outputs": {
                    "base_model_hidden_states": saved["hidden_states"][:limit, :],
                    "aux_hidden_states": saved["aux_hidden_states"][:limit, :],
                }
            },
        }
def _load_conversations(data_path) -> list:
    """Load raw conversation records from ``data_path``.

    ``data_path`` may be:
      * a directory: every ``*.jsonl`` file directly inside it is read in sorted filename
        order, and the records are concatenated;
      * a file whose name ends with ``jsonl``: read as JSON Lines;
      * any other file: parsed as a single JSON document.

    Files are decoded as UTF-8. Blank lines in JSON Lines input are skipped.
    """
    path = Path(data_path)
    if path.is_dir():
        jsonl_files = sorted(path.glob("*.jsonl"))
    elif str(data_path).endswith("jsonl"):
        jsonl_files = [path]
    else:
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    records = []
    for jsonl_file in jsonl_files:
        with open(jsonl_file, encoding="utf-8") as f:
            # json.loads rejects empty strings, so skip blank lines (e.g. a trailing one).
            records.extend(json.loads(line) for line in f if line.strip())
    return records


def _index_offline_files(offline_data_path) -> dict[str, str]:
    """Map conversation IDs to ``.pt`` files found anywhere under ``offline_data_path``.

    The ID is the file name without its ``.pt`` suffix. Paths are visited in sorted
    order, so if one ID occurs in several sub-directories the result is deterministic:
    the path that sorts last wins.
    """
    root = Path(offline_data_path)
    pt_files = {str(p) for p in root.glob("*.pt")} | {str(p) for p in root.glob("**/*.pt")}
    conv_id_to_file = {}
    for pt_path in sorted(pt_files):
        pt_name = Path(pt_path).name
        if pt_name.endswith(".pt"):
            conv_id_to_file[pt_name[: -len(".pt")]] = pt_path
    return conv_id_to_file


def _conversation_id(entry):
    """Return the first non-None of ``conversation_id``, ``uuid`` and ``id`` in ``entry``.

    Raises:
        ValueError: If none of the keys holds a value.
    """
    for key in ("conversation_id", "uuid", "id"):
        conv_id = entry.get(key)
        if conv_id is not None:
            return conv_id
    raise ValueError(f"Conversation ID required but not found for entry {entry}")


def make_eagle_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    max_length=None,
) -> dict:
    """Make dataset and collator for supervised fine-tuning.

    Conversations are loaded from ``data_args.data_path``. They are split 95% / 5%, in
    load order, into train and eval datasets.

    When ``data_args.offline_data_path`` is set, only conversations with a matching
    ``<conversation id>.pt`` file are kept, and they are served by
    ``OfflineSupervisedDataset``. Otherwise ``LazySupervisedDataset`` or
    ``SupervisedDataset`` is used, depending on ``data_args.lazy_preprocess``.

    Args:
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
        data_args: Data arguments. Reads ``data_path``, ``offline_data_path``,
            ``lazy_preprocess``, ``vlm_processor`` and ``vlm_img_dir``.
        max_length: Sequence length passed to the data collator.

    Returns:
        dict: A dictionary containing train and eval datasets and the data collator.

    Raises:
        ValueError: In offline mode, if no ``.pt`` files are found, if an entry has no
            conversation ID, or if no conversation matches a ``.pt`` file.
    """
    if data_args.vlm_processor:
        # trust_remote_code allows the processor repository to run its own Python code.
        vlm_processor = AutoProcessor.from_pretrained(
            data_args.vlm_processor, trust_remote_code=True, use_fast=True
        )
        vlm_img_dir = data_args.vlm_img_dir
    else:
        vlm_processor, vlm_img_dir = None, None

    print_rank_0("Loading input conversations...")
    data_json = _load_conversations(data_args.data_path)

    if data_args.offline_data_path is not None:
        print_rank_0("Loading pre-processed data for offline training...")
        dataset_cls = OfflineSupervisedDataset

        print_rank_0("building conv_id_to_file map...")
        conv_id_to_file = _index_offline_files(data_args.offline_data_path)
        if not conv_id_to_file:
            raise ValueError(f"No .pt files found in {data_args.offline_data_path}")

        print_rank_0("filtering valid entries...")
        records = []
        for entry in data_json:
            file_path = conv_id_to_file.get(str(_conversation_id(entry)))
            if file_path is not None:
                records.append((entry, file_path))

        if not records:
            raise ValueError(
                "No valid files found in the offline data path that match the conversation "
                "IDs in the provided data json. Please ensure that the offline data path is "
                "correct and contains .pt files named after the conversation IDs, and that "
                "the input conversations json has the correct format (with "
                "'conversation_id', 'uuid' or 'id' fields)."
            )
        if len(records) < len(data_json):
            print_rank_0(
                f"Warning: Only {len(records)} out of {len(data_json)} conversations"
                " have corresponding .pt files in the offline data path. Continuing..."
            )
        data_collator = DataCollatorForOffline(max_length=max_length)
    else:
        dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
        records = data_json
        data_collator = DataCollatorWithPadding(max_length=max_length)

    num_train = int(len(records) * 0.95)
    dataset_kwargs = {
        "tokenizer": tokenizer,
        "vlm_processor": vlm_processor,
        "img_dir": vlm_img_dir,
    }
    train_dataset = dataset_cls(records[:num_train], **dataset_kwargs)
    eval_dataset = dataset_cls(records[num_train:], **dataset_kwargs)
    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": data_collator,
    }
class DataCollatorWithPadding:
    """Collate tokenized samples into padded/truncated batch tensors.

    Args:
        max_length: Target sequence length. Longer sequences are truncated and shorter ones
            are zero-padded to this length. When ``None``, sequences are padded to the
            longest ``input_ids`` in the batch instead. Previously ``None`` raised
            TypeError.
    """

    # Per-token fields that are padded/truncated along the sequence dimension.
    # NOTE(review): labels are padded with 0, not IGNORE_TOKEN_ID. Padded positions are
    # excluded only through the zero-padded loss_mask/attention_mask — confirm the loss
    # relies on loss_mask.
    _TOKEN_KEYS = ("input_ids", "attention_mask", "loss_mask", "labels")

    def __init__(self, max_length):
        self.max_length = max_length

    def paddingtensor2d(self, intensors, length):
        """Truncate or zero-pad a 2-D ``(seq, dim)`` tensor along dim 0 to ``length`` rows."""
        n, dim = intensors.shape
        if n > length:
            return intensors[:length, :]
        # new_zeros keeps the padding on the same device and dtype as the input.
        padding_tensor = intensors.new_zeros((length - n, dim))
        return torch.cat((intensors, padding_tensor))

    def paddingtensor(self, intensors, length):
        """Truncate or zero-pad a 1-D tensor to ``length`` elements."""
        n = intensors.shape[0]
        if n > length:
            return intensors[:length]
        padding_tensor = intensors.new_zeros(length - n)
        return torch.cat((intensors, padding_tensor))

    def _target_length(self, features):
        """Return ``max_length``, or the longest ``input_ids`` in the batch when it is None."""
        if self.max_length is not None:
            return self.max_length
        return max(item["input_ids"].shape[0] for item in features)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        length = self._target_length(features)
        batch = {
            key: torch.stack([self.paddingtensor(item[key], length) for item in features])
            for key in self._TOKEN_KEYS
        }
        # Collate VLM data
        if "pixel_values" in features[0]:
            # pixel values and image flags should be flattened inside a batch
            batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0)
            batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0)
        return batch


class DataCollatorForOffline(DataCollatorWithPadding):
    """Collator that also batches the pre-computed base-model hidden states.

    Each feature must carry ``kwargs["base_model_outputs"]`` with
    ``base_model_hidden_states`` and ``aux_hidden_states``. Both are 2-D and are
    padded/truncated along dim 0 to the sequence width of the collated ``input_ids``.

    Raises:
        ValueError: If the features carry no ``kwargs`` entry.
    """

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        base_batch = super().__call__(features)
        if "kwargs" not in features[0]:
            raise ValueError("No kwargs found in batch features. Offline data required.")
        # Match the token tensors exactly, whether max_length is fixed or batch-dependent.
        seq_len = base_batch["input_ids"].shape[1]
        model_outputs = [item["kwargs"]["base_model_outputs"] for item in features]
        batch_hidden_states = torch.stack(
            [
                self.paddingtensor2d(out["base_model_hidden_states"], seq_len)
                for out in model_outputs
            ]
        )
        batch_aux_hidden_states = torch.stack(
            [self.paddingtensor2d(out["aux_hidden_states"], seq_len) for out in model_outputs]
        )
        return {
            **base_batch,
            "base_model_outputs": {
                "base_model_hidden_states": batch_hidden_states,
                "aux_hidden_states": batch_aux_hidden_states,
            },
        }
class EagleTrainerWithAccLog(Trainer):
    """Trainer that records the EAGLE model's training accuracy in the trainer state.

    Accuracies the model reports as ``outputs.train_acc`` during training-mode forward
    passes are appended to ``self.state.training_accs``, which is created on first use.
    A callback such as ``EagleTrainingPlot`` can then aggregate and log them.
    """

    def compute_loss(self, *args, **kwargs):
        """Compute the loss and record ``outputs.train_acc`` while the model is training.

        Outputs are always requested from the parent implementation so the accuracy can
        be read. They are returned to the caller only when it asked for them through
        ``return_outputs`` (keyword, or third positional argument). In that case
        ``(loss, outputs)`` is returned; otherwise just ``loss``.
        """
        if not hasattr(self.state, "training_accs"):
            self.state.training_accs = []
        # NOTE(review): num_items_in_batch is never forwarded to the parent; presumably the
        # EAGLE loss does its own normalization — confirm.
        kwargs.pop("num_items_in_batch", None)
        # Pop the caller's flag so it is not passed to the parent twice; that raised
        # TypeError, e.g. from evaluation's prediction_step.
        return_outputs = kwargs.pop("return_outputs", False)
        if len(args) > 2:
            # Assumes the HF signature compute_loss(model, inputs, return_outputs=False,
            # num_items_in_batch=None); a positional num_items_in_batch is dropped as well.
            return_outputs = args[2]
            args = args[:2]
        loss, outputs = super().compute_loss(*args, return_outputs=True, **kwargs)
        model = args[0] if args else kwargs.get("model")
        # Keep evaluation-mode accuracies out of the training statistics.
        if getattr(model, "training", True) and hasattr(outputs, "train_acc"):
            self.state.training_accs.append(outputs.train_acc)
        return (loss, outputs) if return_outputs else loss
class EagleTrainingPlot(TrainerCallback):
    """Callback that plot training acc and AR during training.

    Args:
        ar_validate_steps: Run AR validation every this many global steps; ``<= 0``
            disables it.
        estimate_ar: Also derive an estimated AR from the training accuracies at each log
            step.
    """

    def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
        self.ar_validate_steps = ar_validate_steps
        if wandb and is_master():
            wandb.init()
        self.estimate_ar = estimate_ar

    @staticmethod
    def _estimate_ar(average_acc) -> float:
        """Estimate the acceptance rate from averaged per-step accuracies.

        The estimate is 1 plus the running products of the accuracies along
        ``average_acc[0]`` (the EAGLE steps). The product is then continued with the last
        entry of each following row, because parallel draft tokens are only used after all
        EAGLE tokens. This is only an estimate of the real AR.
        """
        est_ar = 1
        acc_cumprod = 1
        for step_acc in average_acc[0]:
            acc_cumprod *= step_acc
            est_ar += acc_cumprod
        for draft_acc in average_acc[1:]:
            acc_cumprod *= draft_acc[-1]
            est_ar += acc_cumprod
        return est_ar

    def on_log(self, args, state, control, **kwargs):
        """Log training acc and estimate AR during log step.

        Accuracies collected since the previous log are averaged element-wise, optionally
        turned into an AR estimate, logged to wandb on the master rank, and then cleared.
        """
        if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
            return control
        average_acc = np.mean(state.training_accs, axis=0)
        est_ar = None
        if self.estimate_ar:
            est_ar = self._estimate_ar(average_acc)
            print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
        if wandb and is_master():
            for i, draft_acc in enumerate(average_acc):
                for j, step_acc in enumerate(draft_acc):
                    wandb.log(
                        {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
                    )
            if self.estimate_ar:
                wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
        # Start a fresh accumulation window for the next log step.
        state.training_accs = []
        return control

    def on_step_end(self, args, state, control, **kwargs):
        """Run AR validation periodically, if available.

        Validation is best-effort. Any failure (e.g. the MT-Bench prompts cannot be
        downloaded) is reported with its reason, and training continues.
        """
        if self.ar_validate_steps <= 0:
            return control
        if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
            print_rank_0("Running AR validation...")
            try:
                ars = validate_ar(
                    model=kwargs["model"],
                    tokenizer=kwargs["processing_class"],
                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                    device=kwargs["model"].device,
                )
                mean_ar = sum(ars) / len(ars)
                print_rank_0(f"Step {state.global_step} AR: {mean_ar:.4f}")
                if wandb and is_master():
                    wandb.log({"validate_ar": mean_ar}, step=state.global_step)
            except Exception as e:
                # Deliberately non-fatal, but surface the cause instead of hiding it.
                print_rank_0(f"AR validation not available: {e!r}")
        return control
def get_patched_templated_ring_attn(orig_templated_attn: Callable):
    """
    Return patched version of
    torch.distributed.tensor.experimental._attention._templated_ring_attention
    to support TTT.

    The returned wrapper forwards every argument to ``orig_templated_attn`` unchanged,
    except ``args[2]``, the per-shard attention op. That op is replaced by a wrapper that
    substitutes a TTT attention bias for the shard currently being processed.
    ``patch_ring_attention_for_ttt`` applies this same factory to both the forward and the
    backward ring-attention functions.

    Args:
        orig_templated_attn: The torch ring-attention function to wrap.

    Returns:
        Callable: ``patched_templated_attn``, with the same calling convention as
        ``orig_templated_attn``.
    """

    def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
        """Get chunk-interleaved TTT mask for current rank.

        Builds the additive attention bias between the query shard owned by ``rank`` and
        the key/value shard it holds at ring iteration ``i``. That key/value shard came
        from rank ``(rank - i) % size``. The global key/value sequence is treated as
        ``ttt_step + 1`` blocks of ``size * q_len`` positions, and the selected shard takes
        ``q_len`` consecutive positions from each block.

        Args:
            i: Ring iteration index.
            rank: Rank of the current process in the ring.
            size: Number of ranks in the ring.
            q_len: Per-rank query length.
            ttt_step: Number of extra TTT blocks in the key/value sequence.
            dtype: dtype of the returned bias.

        Returns:
            Tensor: 0 where the mask from ``get_ttt_msk_func`` is True, and
            ``torch.finfo(dtype).min`` elsewhere. Presumably its shape broadcasts to
            ``(1, 1, q_len, (ttt_step + 1) * q_len)`` — depends on get_ttt_msk_func.

        e.g. ("x" appears to mark attended positions; "stepN" means ring iteration i=N):
            2 ranks, ttt_step=1;
            full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
                             [x, 0, 0, 0, 0, x, 0, 0],
                             [x, x, 0, 0, 0, 0, x, 0],
                             [x, x, x, 0, 0, 0, 0, x]]
            rank 0, step0: [[0, 0, x, 0],
                            [x, 0, 0, x]]
            rank 1, step0: [[0, 0, x, 0],
                            [x, 0, 0, x]]
            rank 0, step1: [[0, 0, 0, 0],
                            [0, 0, 0, 0]]
            rank 1, step1: [[x, x, 0, 0],
                            [x, x, 0, 0]]
        """
        device = torch.cuda.current_device()
        # Global positions of this rank's query shard.
        q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
        # Global positions of the key/value shard received at ring iteration i: the
        # (rank - i) % size chunk of every TTT block, concatenated.
        kv_indices = (
            torch.arange(q_len * size * (ttt_step + 1), device=device)
            .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
            .reshape(-1)
        )
        msk_func = get_ttt_msk_func(q_len * size, ttt_step)
        # Query indices as a column and kv indices as a row, so the mask function is
        # evaluated over every (query, kv) pair.
        attn_mask = msk_func(
            None,
            None,
            q_indices.view(1, 1, -1, 1),
            kv_indices.view(1, 1, 1, -1),
        )
        # Convert the boolean mask into an additive bias.
        attn_bias = torch.where(
            attn_mask,
            torch.zeros((), dtype=dtype, device=attn_mask.device),
            torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
        )
        return attn_bias

    def patched_templated_attn(*args, **kwargs):
        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
        # Get original attention op
        # Sensitive to impl of _templated_ring_attention
        original_op = args[2]

        # This patch is only enabled for eagle model by context manager, not base model.
        patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH

        if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

        # Unset is_causal to use custom attn mask
        if patch_enbabled:
            kwargs["is_causal"] = False

        # NOTE(review): patched_op is installed even when patch_enbabled is False, so any call
        # that passes attn_bias gets the TTT bias regardless — confirm this is intended for
        # the base model.
        def patched_op(*args, **kwargs):
            # Inspect the parent frame to get current shard info
            # This is sensitive to torch _templated_ring_attention impl
            # (assumes the forward and backward ring loops both define locals named
            # rank, size, query, key and i — TODO confirm for each torch version).
            try:
                frame: FrameType = inspect.currentframe()
                f_back: FrameType = frame.f_back
                rank = f_back.f_locals["rank"]
                size = f_back.f_locals["size"]
                query = f_back.f_locals["query"]
                key = f_back.f_locals["key"]
                i = f_back.f_locals["i"]
                # Dim 2 is read as the sequence dimension; keys are longer than queries by
                # whole multiples, one per TTT step.
                ttt_step = (key.shape[2] // query.shape[2]) - 1
            except Exception as e:
                raise RuntimeError(
                    f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
                ) from e

            # Set attn mask to permuted TTT mask
            if "attn_bias" in kwargs:
                kwargs["attn_bias"] = _get_sharded_ttt_msk(
                    i, rank, size, query.shape[2], ttt_step, query.dtype
                )

            # Perform shard attention
            return original_op(*args, **kwargs)

        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)

    return patched_templated_attn
def patch_ring_attention_for_ttt():
    """Patch torch ring attention to support context parallelism for TTT.

    Applies three process-wide monkey patches to
    ``torch.distributed.tensor.experimental._attention``:

    1. Sets ``_cp_options.enable_load_balance = False``.
    2. Wraps ``_templated_ring_attention`` and ``_templated_ring_attention_backward`` with
       ``get_patched_templated_ring_attn``.
    3. Makes ``_SDPAMerger.step`` skip shards whose ``lse`` sums to <= 0.

    The patches are never undone. Calling this function again wraps the already-patched
    functions a second time.

    Raises:
        RuntimeError: If ``torch.__version__`` is not strictly between 2.7.1 and 2.9.0.
    """
    # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.

    if not (
        Version(torch.__version__) > Version("2.7.1")
        and Version(torch.__version__) < Version("2.9.0")
    ):
        raise RuntimeError(
            f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
            f"Got {torch.__version__}. "
            f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
        )

    # 1. Disable load balance, which is designed for causal mask.
    # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False

    # 2. Patch templated ring attention for TTT mask.
    original_templated_ring_attention = (
        torch.distributed.tensor.experimental._attention._templated_ring_attention
    )
    original_templated_ring_attention_backward = (
        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
    )
    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
        get_patched_templated_ring_attn(original_templated_ring_attention)
    )
    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
    )

    # 3. Patch merger to skip the blank shard to avoid difference in output.
    original_sdpa_merger_step = _SDPAMerger.step

    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
        # NOTE(review): a fully masked shard is expected to yield non-positive lse, but a real
        # shard whose scores are all small could also sum to <= 0 and would be dropped —
        # confirm the heuristic. Evaluating the comparison on a CUDA tensor also forces a
        # host-device sync.
        if lse.sum() <= 0:
            return
        return original_sdpa_merger_step(self, out, lse, partial)

    _SDPAMerger.step = patched_sdpa_merger_step