Merged
22 changes: 16 additions & 6 deletions scripts/prepare_hidden_states.py
@@ -46,7 +46,7 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoProcessor, AutoTokenizer

from datasets import load_dataset
from datasets import Dataset
from specforge.args import SGLangBackendArgs
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
from specforge.distributed import (
@@ -57,7 +57,11 @@
is_tp_rank_0,
)
from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
from specforge.utils import print_with_rank, rank_0_priority
from specforge.utils import (
print_with_rank,
rank_0_priority,
safe_conversations_generator,
)


@dataclass
@@ -469,7 +473,6 @@ def generate(
filtered_batch_gpu = {
k: v.cuda(non_blocking=True) for k, v in filtered_batch.items()
}

_, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend(
**filtered_batch_gpu,
return_last_hidden_states=True,
@@ -574,10 +577,17 @@ def main():
assert os.path.exists(
args.data_path
), f"Dataset path {args.data_path} does not exist"
dataset = load_dataset("json", data_files=args.data_path)["train"]
dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.data_path},
cache_dir=os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"cache",
"hf_dataset",
),
)
if args.num_samples is not None:
dataset = dataset.select(range(args.num_samples))

# Tokenizer and cache key
tokenizer = AutoTokenizer.from_pretrained(
args.target_model_path, trust_remote_code=True
@@ -643,7 +653,7 @@ def main():
# Pass configurable arguments from args if needed
with HiddenStatesGenerator(
target_model,
args.enable_aux_hidden_states,
enable_aux_hidden_states=args.enable_aux_hidden_states,
num_io_threads=args.num_io_threads,
io_queue_size=args.io_queue_size,
file_group_size=args.file_group_size,
23 changes: 16 additions & 7 deletions scripts/train_eagle3.py
@@ -17,7 +17,7 @@
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer

from datasets import load_dataset
from datasets import Dataset
from specforge import (
AutoDraftModelConfig,
AutoEagle3DraftModel,
@@ -53,6 +53,7 @@
print_on_rank0,
print_with_rank,
rank_0_priority,
safe_conversations_generator,
)


@@ -412,7 +413,10 @@ def build_dataloaders(
f"{args.target_model_path}" # Tokenizer may also different
)
cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
train_dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.train_data_path},
)
is_online = (
args.train_data_path is not None and args.train_hidden_states_path is None
)
@@ -458,7 +462,10 @@ def build_dataloaders(
)
if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
if args.eval_data_path is not None:
eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
eval_dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.eval_data_path},
)
eval_eagle3_dataset = build_eagle3_dataset(
eval_dataset,
tokenizer,
@@ -589,14 +596,16 @@ def run_forward(
hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
else:
# we generate the logits using the hidden states loaded from disk
input_ids = data["input_ids"].cuda()
attention_mask = data["attention_mask"].cuda()
loss_mask = data["loss_mask"].cuda()
hidden_states = data["hidden_state"].cuda()
target = target_model(data["target"].cuda())
input_ids, target, loss_mask = target_model.preprocess(
input_ids, target, loss_mask
data["input_ids"], data["target"], data["loss_mask"]
)
input_ids = input_ids.cuda()
target = target_model(
target.cuda()
)  # The `data['target']` tensor has shape [seqlen, vocab_size] and would occupy a large amount of GPU memory, so it is preprocessed before being moved onto the GPU.
loss_mask = loss_mask.cuda()
Comment on lines +604 to +608
Collaborator:

What is the impact of this on performance? If it is large, maybe we can add a flag to control whether this is done on the GPU or the CPU.

Collaborator Author:

We can estimate it directly: vocab size (150,000) × seq_length (64k) costs roughly 10 GB more.

Collaborator Author:

The reason is that target_head's preprocess function uses padding, which generates an extra copy of the target tensor in memory.

Collaborator:

We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.

Collaborator Author:

> We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.

Yes, I think this is a better optimization. Can you help add it?

Collaborator:

#454
Hi, I've finished the updates. Note that SP currently works with batch size 1. This seems reasonable for long-sequence scenarios to avoid OOM, but I'm open to feedback. Ready for review!

Collaborator Author:

Great, I'll help review it.
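
As a sanity check on the estimate above, here is a minimal back-of-the-envelope sketch; the dtype and exact sequence length are assumptions, so treat the numbers as order-of-magnitude only:

```python
# Order-of-magnitude estimate for data["target"] with shape [seq_len, vocab_size].
seq_len = 64 * 1024      # assumed 64k tokens
vocab_size = 150_000     # assumed vocab size
elements = seq_len * vocab_size

for name, bytes_per_elem in [("fp8", 1), ("bf16", 2), ("fp32", 4)]:
    gib = elements * bytes_per_elem / 1024**3
    print(f"{name}: one copy of target ≈ {gib:.1f} GiB")
# fp8:  one copy of target ≈  9.2 GiB
# bf16: one copy of target ≈ 18.3 GiB
# fp32: one copy of target ≈ 36.6 GiB
# preprocess() pads the target and transiently holds a second copy, so doing
# this on the GPU roughly doubles the peak footprint at these shapes.
```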

plosses, _, acces = eagle3_model(
input_ids=input_ids,
attention_mask=attention_mask,
50 changes: 34 additions & 16 deletions specforge/core/eagle3.py
@@ -25,6 +25,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers.cache_utils import DynamicCache
from yunchang import EXTRACT_FUNC_DICT

@@ -122,6 +123,7 @@ def forward(
length=self.length,
)
del target
torch.cuda.empty_cache()
Contributor (medium):

Explicitly calling torch.cuda.empty_cache() can introduce significant performance overhead due to CPU-GPU synchronization. The preceding del target should be sufficient to free the tensor's memory if there are no other references. Is this call strictly necessary for memory optimization in this case? If so, a comment explaining why would be helpful for future maintenance.
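
If the cost is in question, a minimal sketch for timing the call in isolation (assumes a CUDA device; the result depends on allocator state and hardware):

```python
import time

import torch

# Allocate and drop a tensor so the caching allocator has blocks to release.
x = torch.empty(8192, 8192, device="cuda", dtype=torch.bfloat16)
del x
torch.cuda.synchronize()

t0 = time.perf_counter()
torch.cuda.empty_cache()  # returns cached, unoccupied blocks to the driver
torch.cuda.synchronize()
print(f"empty_cache(): {(time.perf_counter() - t0) * 1e3:.2f} ms")
```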


# basic info
batch_size, seq_length, _ = hidden_states.shape
@@ -166,7 +168,7 @@ def forward(
dtype=torch.bool,
device=hidden_states.device,
)
if self.attention_backend in ("sdpa", "usp"):
if self.attention_backend == "sdpa":
attention_mask = self.draft_model.prepare_decoder_attention_mask(
attention_mask=attention_mask,
hidden_states=hidden_states,
@@ -175,6 +177,24 @@
past_key_values_length=past_key_values_length,
)

def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask):
# 1. Compute logits (this is the step that consumes the most VRAM)
logits_ = self.draft_model.compute_logits(hs)
logits = gather_outputs_and_unpad(logits_, gather_dim=1)

# 2. Compute Loss
loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask)

# 3. Compute Accuracy
with torch.no_grad():
acc_val = _compute_metric_acc(
logits=logits,
target_p=tgt_p,
position_mask=pos_mask,
loss_mask=l_mask,
)
return loss_val, acc_val

# Step 5: run TTT
plosses = []
vlosses = []
@@ -217,24 +237,22 @@ def forward(
# update hidden states for next step
hidden_states = hidden_states_out

# Step 5.4: get logits
logits = self.draft_model.compute_logits(hidden_states)
logits = gather_outputs_and_unpad(logits, gather_dim=1)
# Step 5.5: record metrics first as we in-place modify logits
with torch.no_grad():
acces.append(
_compute_metric_acc(
logits=logits,
target_p=target_p,
position_mask=position_mask,
loss_mask=loss_mask,
)
if hidden_states.requires_grad:
loss, acc = checkpoint(
compute_loss_and_acc_checkpointed,
hidden_states,
target_p,
position_mask,
loss_mask,
use_reentrant=False,
)
else:
loss, acc = compute_loss_and_acc_checkpointed(
hidden_states, target_p, position_mask, loss_mask
)

# Step 5.6: calculate loss, in-place modifies logits!
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
plosses.append(loss)

acces.append(acc)
if not is_last:
# Step 5.7: we need to update the loss mask
global_input_ids = padding(global_input_ids, left=False)
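A note on the pattern introduced above: wrapping compute_logits and the loss in torch.utils.checkpoint means the [batch, seq, vocab]-sized logits are not kept alive for the backward pass; they are recomputed from the much smaller hidden states during backprop. A standalone sketch of the trade-off, using an illustrative linear head rather than the draft model's real API:

```python
import torch
from torch.utils.checkpoint import checkpoint

hidden = torch.randn(1, 4096, 1024, device="cuda", requires_grad=True)
lm_head = torch.nn.Linear(1024, 150_000, bias=False, device="cuda")

def loss_from_hidden(hs):
    logits = lm_head(hs)  # [1, 4096, 150000] -- the large intermediate tensor
    return logits.logsumexp(dim=-1).mean()

# Checkpointed: logits are freed after the forward pass and recomputed during
# backward, trading extra compute for a much lower peak memory footprint.
loss = checkpoint(loss_from_hidden, hidden, use_reentrant=False)
loss.backward()
```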
18 changes: 16 additions & 2 deletions specforge/data/parse.py
@@ -1,3 +1,4 @@
import json
import re
import warnings
from abc import ABC, abstractmethod
@@ -111,6 +112,13 @@ def parse(
f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
)
break
tool_calls = sentence.get("tool_calls")
if isinstance(tool_calls, str):
try:
sentence["tool_calls"] = json.loads(tool_calls)
except json.JSONDecodeError:
warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}")
break
messages.append(sentence)

try:
@@ -164,11 +172,17 @@ def parse(
# --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length ---
# Encode the text "assistant start", the length of which is the position of the starting token.
prefix_ids = self.tokenizer.encode(
conversation[:content_start_char], add_special_tokens=False
conversation[:content_start_char],
add_special_tokens=False,
truncation=True,
max_length=max_length,
)
# Encodes the text "assistant end", the length of which is the position of the end token.
full_ids = self.tokenizer.encode(
conversation[:content_end_char], add_special_tokens=False
conversation[:content_end_char],
add_special_tokens=False,
truncation=True,
max_length=max_length,
)

start_token_idx = len(prefix_ids)
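To illustrate the tool_calls handling added in parse.py above, a small sketch; the message content here is made up:

```python
import json

msg = {
    "role": "assistant",
    "content": "",
    # Some datasets serialize tool_calls as a JSON string rather than a list of dicts.
    "tool_calls": '[{"name": "get_weather", "arguments": {"city": "Paris"}}]',
}

tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, str):
    try:
        # After parsing, chat templates see the structured list they expect.
        msg["tool_calls"] = json.loads(tool_calls)
    except json.JSONDecodeError:
        print(f"Failed to parse tool_calls JSON: {tool_calls}")
```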
1 change: 0 additions & 1 deletion specforge/modeling/target/target_head.py
@@ -89,5 +89,4 @@ def preprocess(self, input_ids, target, loss_mask):
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
loss_mask = loss_mask[..., None]
loss_mask = loss_mask.to(target.device)
return input_ids, target, loss_mask
56 changes: 56 additions & 0 deletions specforge/utils.py
@@ -301,3 +301,59 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
state[k] = distribute_tensor(
v.to(p.device), device_mesh=mesh, placements=placements
)


def safe_conversations_generator(file_path):
"""
Generator that:
1. Extracts the 'conversations' field.
2. Preserves all original fields within each message.
3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility).
"""
with open(file_path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
raw_convs = row.get("conversations", [])

# 1. Ensure 'conversations' is a list
if not isinstance(raw_convs, list):
# If it's None or some unexpected type, treat as empty or skip
if raw_convs is None:
raw_convs = []
else:
# Edge case: 'conversations' is a plain string or non-iterable—skip this line
logger.warning(
f"Line {i + 1}: 'conversations' is not a list. Please check!"
)
continue

cleaned_convs = []
for msg in raw_convs:
# 2. Ensure each item in the list is a dictionary
if not isinstance(msg, dict):
# Skip if an element is not a dict (e.g., malformed like ["user", "hi"])
continue

# 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.)
new_msg = {}
for k, v in msg.items():
# If the value is a list or dict, serialize it to a JSON string
# This ensures Arrow treats the column as string type instead of list/struct
if isinstance(v, (list, dict)):
new_msg[k] = json.dumps(v, ensure_ascii=False)
else:
# Keep primitive types (str, int, float, bool, None) unchanged
new_msg[k] = v

cleaned_convs.append(new_msg)

# Yield only the processed 'conversations'
yield {"conversations": cleaned_convs}

except Exception as e:
logger.warning(f"Skipping line {i + 1}: {e}")
continue
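
For reference, a minimal sketch of how the new generator is consumed, mirroring the scripts above; the file path and cache directory here are placeholders:

```python
from datasets import Dataset

from specforge.utils import safe_conversations_generator

# Every list/dict field inside a message is serialized to a JSON string by the
# generator, so Arrow infers uniform string columns instead of mixed types.
dataset = Dataset.from_generator(
    generator=safe_conversations_generator,
    gen_kwargs={"file_path": "data/train.jsonl"},  # placeholder path
    cache_dir="cache/hf_dataset",                  # placeholder cache dir
)
print(dataset[0]["conversations"][0])
```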