diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index a1d45fe1..d201ca47 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -46,7 +46,7 @@
 from tqdm import tqdm
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 
-from datasets import load_dataset
+from datasets import Dataset
 from specforge.args import SGLangBackendArgs
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.distributed import (
@@ -57,7 +57,11 @@
     is_tp_rank_0,
 )
 from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
-from specforge.utils import print_with_rank, rank_0_priority
+from specforge.utils import (
+    print_with_rank,
+    rank_0_priority,
+    safe_conversations_generator,
+)
 
 
 @dataclass
@@ -469,7 +473,6 @@ def generate(
         filtered_batch_gpu = {
             k: v.cuda(non_blocking=True) for k, v in filtered_batch.items()
         }
-
         _, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend(
             **filtered_batch_gpu,
             return_last_hidden_states=True,
@@ -574,10 +577,17 @@ def main():
     assert os.path.exists(
         args.data_path
    ), f"Dataset path {args.data_path} does not exist"
-    dataset = load_dataset("json", data_files=args.data_path)["train"]
+    dataset = Dataset.from_generator(
+        generator=safe_conversations_generator,
+        gen_kwargs={"file_path": args.data_path},
+        cache_dir=os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "cache",
+            "hf_dataset",
+        ),
+    )
     if args.num_samples is not None:
         dataset = dataset.select(range(args.num_samples))
-
     # Tokenizer and cache key
     tokenizer = AutoTokenizer.from_pretrained(
         args.target_model_path, trust_remote_code=True
@@ -643,7 +653,7 @@ def main():
     # Pass configurable arguments from args if needed
     with HiddenStatesGenerator(
         target_model,
-        args.enable_aux_hidden_states,
+        enable_aux_hidden_states=args.enable_aux_hidden_states,
         num_io_threads=args.num_io_threads,
         io_queue_size=args.io_queue_size,
         file_group_size=args.file_group_size,
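
Note (not part of the patch): the motivation for replacing load_dataset("json", ...) is that Arrow infers a single schema for the whole JSONL file, so rows whose fields mix types (e.g. tool_calls stored as a list on one line and as a string on another) can fail to load. Below is a minimal sketch of the failure mode and the generator-based workaround; the file path and field values are invented, and the toy normalize() only mirrors what safe_conversations_generator (added in specforge/utils.py later in this patch) does.

import json
from datasets import Dataset

rows = [
    {"conversations": [{"role": "assistant", "tool_calls": [{"name": "f"}]}]},
    {"conversations": [{"role": "assistant", "tool_calls": "[]"}]},  # str vs. list
]
with open("/tmp/mixed.jsonl", "w", encoding="utf-8") as f:
    f.writelines(json.dumps(r) + "\n" for r in rows)

def normalize(file_path):
    # Serialize list/dict fields to strings so Arrow sees one consistent type.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            for msg in row.get("conversations", []):
                if isinstance(msg.get("tool_calls"), (list, dict)):
                    msg["tool_calls"] = json.dumps(msg["tool_calls"])
            yield {"conversations": row["conversations"]}

# load_dataset("json", data_files="/tmp/mixed.jsonl") would typically fail with
# an Arrow type error on such a file; the generator path loads cleanly.
ds = Dataset.from_generator(normalize, gen_kwargs={"file_path": "/tmp/mixed.jsonl"})
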
diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py
index b5ab6a1d..929e2640 100644
--- a/scripts/train_eagle3.py
+++ b/scripts/train_eagle3.py
@@ -17,7 +17,7 @@
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoTokenizer
 
-from datasets import load_dataset
+from datasets import Dataset
 from specforge import (
     AutoDraftModelConfig,
     AutoEagle3DraftModel,
@@ -53,6 +53,7 @@
     print_on_rank0,
     print_with_rank,
     rank_0_priority,
+    safe_conversations_generator,
 )
 
 
@@ -412,7 +413,10 @@ def build_dataloaders(
         f"{args.target_model_path}"  # Tokenizer may also different
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    train_dataset = Dataset.from_generator(
+        generator=safe_conversations_generator,
+        gen_kwargs={"file_path": args.train_data_path},
+    )
     is_online = (
         args.train_data_path is not None and args.train_hidden_states_path is None
     )
@@ -458,7 +462,10 @@
         )
     if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
         if args.eval_data_path is not None:
-            eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
+            eval_dataset = Dataset.from_generator(
+                generator=safe_conversations_generator,
+                gen_kwargs={"file_path": args.eval_data_path},
+            )
         eval_eagle3_dataset = build_eagle3_dataset(
             eval_dataset,
             tokenizer,
@@ -589,14 +596,16 @@ def run_forward(
         hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
     else:
         # we generate the logits using the hidden states loaded from disk
-        input_ids = data["input_ids"].cuda()
         attention_mask = data["attention_mask"].cuda()
-        loss_mask = data["loss_mask"].cuda()
         hidden_states = data["hidden_state"].cuda()
-        target = target_model(data["target"].cuda())
         input_ids, target, loss_mask = target_model.preprocess(
-            input_ids, target, loss_mask
+            data["input_ids"], data["target"], data["loss_mask"]
         )
+        input_ids = input_ids.cuda()
+        target = target_model(
+            target.cuda()
+        )  # data["target"] has shape [seqlen, vocab_size] and would occupy a large amount of GPU memory, so it is preprocessed on CPU before being moved to the GPU
+        loss_mask = loss_mask.cuda()
         plosses, _, acces = eagle3_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
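
Note (not part of the patch): the run_forward change above reorders the offline path so that target_model.preprocess operates on CPU tensors and only the final result crosses to the GPU, since data["target"] is [seqlen, vocab_size]. A minimal sketch of that ordering follows; the sizes are illustrative and preprocess() is a stand-in, not the real target_model.preprocess.

import torch

seqlen, vocab_size = 4096, 129_024  # illustrative sizes

def preprocess(t):
    # Stand-in: shift/pad along the sequence dimension, still on the CPU.
    return torch.nn.functional.pad(t, (0, 0, 0, 1))[1:]

target = preprocess(torch.randn(seqlen, vocab_size))  # CPU-side work first
if torch.cuda.is_available():
    target = target.cuda()  # the large tensor crosses to the GPU only once
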
diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 154dfd17..abf43527 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -25,6 +25,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from transformers.cache_utils import DynamicCache
 from yunchang import EXTRACT_FUNC_DICT
 
@@ -122,6 +123,7 @@ def forward(
             length=self.length,
         )
         del target
+        torch.cuda.empty_cache()
 
         # basic info
         batch_size, seq_length, _ = hidden_states.shape
@@ -166,7 +168,7 @@ def forward(
             dtype=torch.bool,
             device=hidden_states.device,
         )
-        if self.attention_backend in ("sdpa", "usp"):
+        if self.attention_backend == "sdpa":
             attention_mask = self.draft_model.prepare_decoder_attention_mask(
                 attention_mask=attention_mask,
                 hidden_states=hidden_states,
@@ -175,6 +177,24 @@ def forward(
                 past_key_values_length=past_key_values_length,
             )
 
+        def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask):
+            # 1. Compute logits (this is the step that consumes the most VRAM)
+            logits_ = self.draft_model.compute_logits(hs)
+            logits = gather_outputs_and_unpad(logits_, gather_dim=1)
+
+            # 2. Compute loss
+            loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask)
+
+            # 3. Compute accuracy
+            with torch.no_grad():
+                acc_val = _compute_metric_acc(
+                    logits=logits,
+                    target_p=tgt_p,
+                    position_mask=pos_mask,
+                    loss_mask=l_mask,
+                )
+            return loss_val, acc_val
+
         # Step 5: run TTT
         plosses = []
         vlosses = []
@@ -217,24 +237,22 @@ def forward(
 
             # update hidden states for next step
             hidden_states = hidden_states_out
 
-            # Step 5.4: get logits
-            logits = self.draft_model.compute_logits(hidden_states)
-            logits = gather_outputs_and_unpad(logits, gather_dim=1)
-            # Step 5.5: record metrics first as we in-place modify logits
-            with torch.no_grad():
-                acces.append(
-                    _compute_metric_acc(
-                        logits=logits,
-                        target_p=target_p,
-                        position_mask=position_mask,
-                        loss_mask=loss_mask,
-                    )
+            if hidden_states.requires_grad:
+                loss, acc = checkpoint(
+                    compute_loss_and_acc_checkpointed,
+                    hidden_states,
+                    target_p,
+                    position_mask,
+                    loss_mask,
+                    use_reentrant=False,
+                )
+            else:
+                loss, acc = compute_loss_and_acc_checkpointed(
+                    hidden_states, target_p, position_mask, loss_mask
                 )
-            # Step 5.6: calculate loss, in-place modifies logits!
-            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
             plosses.append(loss)
-
+            acces.append(acc)
             if not is_last:
                 # Step 5.7: we need to update the loss mask
                 global_input_ids = padding(global_input_ids, left=False)
diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index d96f37b1..073e882a 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -1,3 +1,4 @@
+import json
 import re
 import warnings
 from abc import ABC, abstractmethod
@@ -111,6 +112,13 @@ def parse(
                     f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                 )
                 break
+            tool_calls = sentence.get("tool_calls")
+            if isinstance(tool_calls, str):
+                try:
+                    sentence["tool_calls"] = json.loads(tool_calls)
+                except json.JSONDecodeError:
+                    warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}")
+                    break
             messages.append(sentence)
 
         try:
@@ -164,11 +172,17 @@ def parse(
         # --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length ---
         # Encode the text "assistant start", the length of which is the position of the starting token.
         prefix_ids = self.tokenizer.encode(
-            conversation[:content_start_char], add_special_tokens=False
+            conversation[:content_start_char],
+            add_special_tokens=False,
+            truncation=True,
+            max_length=max_length,
         )
         # Encodes the text "assistant end", the length of which is the position of the end token.
         full_ids = self.tokenizer.encode(
-            conversation[:content_end_char], add_special_tokens=False
+            conversation[:content_end_char],
+            add_special_tokens=False,
+            truncation=True,
+            max_length=max_length,
         )
 
         start_token_idx = len(prefix_ids)
diff --git a/specforge/modeling/target/target_head.py b/specforge/modeling/target/target_head.py
index 3918d6d6..7231117c 100644
--- a/specforge/modeling/target/target_head.py
+++ b/specforge/modeling/target/target_head.py
@@ -89,5 +89,4 @@ def preprocess(self, input_ids, target, loss_mask):
         target = padding(target, left=False)
         input_ids = padding(input_ids, left=False)
         loss_mask = loss_mask[..., None]
-        loss_mask = loss_mask.to(target.device)
         return input_ids, target, loss_mask
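
Note (not part of the patch): the specforge/core/eagle3.py change above wraps the logits/loss/accuracy computation in torch.utils.checkpoint.checkpoint with use_reentrant=False, so the [batch, seq, vocab]-sized logits are recomputed during backward instead of being stored for every TTT step, trading compute for activation memory. A self-contained sketch of the same pattern; the layer sizes and loss are illustrative.

import torch
from torch.utils.checkpoint import checkpoint

hidden = torch.randn(2, 128, 1024, requires_grad=True)
proj = torch.nn.Linear(1024, 32_000)  # stand-in for draft_model.compute_logits

def loss_fn(hs):
    logits = proj(hs)  # the large [batch, seq, vocab] activation
    return logits.log_softmax(dim=-1).mean()

# Mirror the patch's guard: checkpointing only helps when gradients flow.
if hidden.requires_grad:
    loss = checkpoint(loss_fn, hidden, use_reentrant=False)
else:
    loss = loss_fn(hidden)
loss.backward()
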
diff --git a/specforge/utils.py b/specforge/utils.py
index 57a423bb..59724a82 100644
--- a/specforge/utils.py
+++ b/specforge/utils.py
@@ -301,3 +301,59 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
             state[k] = distribute_tensor(
                 v.to(p.device), device_mesh=mesh, placements=placements
             )
+
+
+def safe_conversations_generator(file_path):
+    """
+    Generator that:
+    1. Extracts the 'conversations' field from each JSONL line.
+    2. Preserves all original fields within each message.
+    3. [Key step] Converts all list/dict-type field values to JSON strings to resolve mixed-type conflicts (for Arrow compatibility).
+    """
+    with open(file_path, "r", encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                row = json.loads(line)
+                raw_convs = row.get("conversations", [])
+
+                # 1. Ensure 'conversations' is a list
+                if not isinstance(raw_convs, list):
+                    # If it's None, treat it as empty; otherwise skip the line
+                    if raw_convs is None:
+                        raw_convs = []
+                    else:
+                        # Edge case: 'conversations' is a plain string or other non-list value; skip this line
+                        logger.warning(
+                            f"Line {i + 1}: 'conversations' is not a list; skipping."
+                        )
+                        continue
+
+                cleaned_convs = []
+                for msg in raw_convs:
+                    # 2. Ensure each item in the list is a dictionary
+                    if not isinstance(msg, dict):
+                        # Skip elements that are not dicts (e.g., malformed entries like ["user", "hi"])
+                        continue
+
+                    # 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.)
+                    new_msg = {}
+                    for k, v in msg.items():
+                        # If the value is a list or dict, serialize it to a JSON string.
+                        # This ensures Arrow treats the column as a string type instead of list/struct.
+                        if isinstance(v, (list, dict)):
+                            new_msg[k] = json.dumps(v, ensure_ascii=False)
+                        else:
+                            # Keep primitive types (str, int, float, bool, None) unchanged
+                            new_msg[k] = v
+
+                    cleaned_convs.append(new_msg)
+
+                # Yield only the processed 'conversations'
+                yield {"conversations": cleaned_convs}
+
+            except Exception as e:
+                logger.warning(f"Skipping line {i + 1}: {e}")
+                continue
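
Note (not part of the patch): safe_conversations_generator and the tool_calls hook added in specforge/data/parse.py form a round trip: nested fields are serialized to JSON strings before Arrow ingestion and parsed back at template-rendering time. A usage sketch with an invented message:

import json

msg = {"role": "assistant", "content": "", "tool_calls": [{"name": "search"}]}

# What safe_conversations_generator does per message: stringify nested fields.
cleaned = {
    k: json.dumps(v, ensure_ascii=False) if isinstance(v, (list, dict)) else v
    for k, v in msg.items()
}
assert isinstance(cleaned["tool_calls"], str)

# What the parse.py hook undoes downstream: restore the original structure.
assert json.loads(cleaned["tool_calls"]) == msg["tool_calls"]
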