From 6c778df8fd890cde34b24fc272daae649c622214 Mon Sep 17 00:00:00 2001
From: canghua
Date: Wed, 14 Jan 2026 15:39:36 +0800
Subject: [PATCH 01/10] support loading comprehensive JSONL data files

---
 scripts/prepare_hidden_states.py |  5 ++---
 scripts/train_eagle3.py          |  6 +++---
 specforge/utils.py               | 13 +++++++++++++
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index a1d45fe18..690111557 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -46,7 +46,6 @@
 from tqdm import tqdm
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 
-from datasets import load_dataset
 from specforge.args import SGLangBackendArgs
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.distributed import (
@@ -57,7 +56,7 @@
     is_tp_rank_0,
 )
 from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
-from specforge.utils import print_with_rank, rank_0_priority
+from specforge.utils import load_dataset_from_jsonl, print_with_rank, rank_0_priority
 
 
 @dataclass
@@ -574,7 +573,7 @@ def main():
     assert os.path.exists(
         args.data_path
    ), f"Dataset path {args.data_path} does not exist"
-    dataset = load_dataset("json", data_files=args.data_path)["train"]
+    dataset = load_dataset_from_jsonl(data_path=args.data_path)
     if args.num_samples is not None:
         dataset = dataset.select(range(args.num_samples))
 
diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py
index 04e07b479..e3412b744 100644
--- a/scripts/train_eagle3.py
+++ b/scripts/train_eagle3.py
@@ -17,7 +17,6 @@
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoTokenizer
 
-from datasets import load_dataset
 from specforge import (
     AutoDraftModelConfig,
     AutoEagle3DraftModel,
@@ -48,6 +47,7 @@ from specforge.utils import (
     create_draft_config_from_target,
     get_last_checkpoint,
+    load_dataset_from_jsonl,
     print_args_with_dots,
     print_on_rank0,
     print_with_rank,
     rank_0_priority,
 )
@@ -409,7 +409,7 @@
         f"{args.target_model_path}"  # Tokenizer may also different
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    train_dataset = load_dataset_from_jsonl(data_path=args.train_data_path)
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -450,7 +450,7 @@
     if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
         if args.eval_data_path is not None:
-            eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
+            eval_dataset = load_dataset_from_jsonl(data_path=args.eval_data_path)
             eval_eagle3_dataset = build_eagle3_dataset(
                 eval_dataset,
                 tokenizer,
diff --git a/specforge/utils.py b/specforge/utils.py
index 57a423bbd..f33324f7c 100644
--- a/specforge/utils.py
+++ b/specforge/utils.py
@@ -9,6 +9,8 @@
 from torch.distributed._tensor import DTensor, Shard, distribute_tensor
 from transformers import AutoConfig, PretrainedConfig
 
+from datasets import Dataset
+
 logger = logging.getLogger(__name__)
 
 
@@ -301,3 +303,14 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
             state[k] = distribute_tensor(
                 v.to(p.device), device_mesh=mesh, placements=placements
             )
+
+
+def load_dataset_from_jsonl(data_path: str) -> Dataset:
+    data = []
+    with open(data_path, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                data.append(json.loads(line))
+    dataset = Dataset.from_list(data)
+    return dataset

From 2560f941658246445b7a20646b3407d0da0a37fb Mon Sep 17 00:00:00 2001
From: canghua
Date: Thu, 15 Jan 2026 22:08:19 +0800
Subject: [PATCH 02/10] support handling different tool-use message formats

---
 scripts/prepare_hidden_states.py | 11 +++++-
 scripts/train_eagle3.py          | 13 +++++--
 specforge/data/parse.py          |  7 +++-
 specforge/utils.py               | 63 +++++++++++++++++++++++++++-----
 4 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index 690111557..4df35011a 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -46,6 +46,7 @@
 from tqdm import tqdm
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 
+from datasets import Dataset
 from specforge.args import SGLangBackendArgs
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.distributed import (
@@ -56,7 +57,11 @@
     is_tp_rank_0,
 )
 from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
-from specforge.utils import load_dataset_from_jsonl, print_with_rank, rank_0_priority
+from specforge.utils import (
+    print_with_rank,
+    rank_0_priority,
+    safe_conversations_generator,
+)
 
 
 @dataclass
@@ -573,7 +578,9 @@ def main():
     assert os.path.exists(
         args.data_path
     ), f"Dataset path {args.data_path} does not exist"
-    dataset = load_dataset_from_jsonl(data_path=args.data_path)
+    dataset = Dataset.from_generator(
+        generator=safe_conversations_generator, gen_kwargs={"file_path": args.data_path}
+    )
     if args.num_samples is not None:
         dataset = dataset.select(range(args.num_samples))
 
diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py
index e3412b744..710ada8e0 100644
--- a/scripts/train_eagle3.py
+++ b/scripts/train_eagle3.py
@@ -17,6 +17,7 @@
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoTokenizer
 
+from datasets import Dataset
 from specforge import (
     AutoDraftModelConfig,
     AutoEagle3DraftModel,
@@ -47,11 +48,11 @@ from specforge.utils import (
     create_draft_config_from_target,
     get_last_checkpoint,
-    load_dataset_from_jsonl,
     print_args_with_dots,
     print_on_rank0,
     print_with_rank,
     rank_0_priority,
+    safe_conversations_generator,
 )
@@ -409,7 +410,10 @@
         f"{args.target_model_path}"  # Tokenizer may also different
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
-    train_dataset = load_dataset_from_jsonl(data_path=args.train_data_path)
+    train_dataset = Dataset.from_generator(
+        generator=safe_conversations_generator,
+        gen_kwargs={"file_path": args.train_data_path},
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -450,7 +454,10 @@
     if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
         if args.eval_data_path is not None:
-            eval_dataset = load_dataset_from_jsonl(data_path=args.eval_data_path)
+            eval_dataset = Dataset.from_generator(
+                generator=safe_conversations_generator,
+                gen_kwargs={"file_path": args.eval_data_path},
+            )
             eval_eagle3_dataset = build_eagle3_dataset(
                 eval_dataset,
                 tokenizer,
diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index e0e316b2d..3125412a1 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -1,3 +1,4 @@
+import json
 import re
 import warnings
 from abc import ABC, abstractmethod
@@ -91,10 +92,14 @@ def parse(
                     f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                 )
                 break
+            if sentence["tool_calls"] is not None:
+                sentence["tool_calls"] = json.loads(sentence["tool_calls"])
             messages.append(sentence)
 
         try:
-            conversation = self.apply_chat_template(messages, **kwargs)
+            conversation = self.apply_chat_template(
+                messages, max_length=max_length, **kwargs
+            )
         except (ValueError, TypeError):
             # Fallback rendering for tokenizers without built-in chat_template
             warnings.warn(
diff --git a/specforge/utils.py b/specforge/utils.py
index f33324f7c..859d064cc 100644
--- a/specforge/utils.py
+++ b/specforge/utils.py
@@ -9,8 +9,6 @@
 from torch.distributed._tensor import DTensor, Shard, distribute_tensor
 from transformers import AutoConfig, PretrainedConfig
 
-from datasets import Dataset
-
 logger = logging.getLogger(__name__)
 
 
@@ -305,12 +303,57 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
             )
 
 
-def load_dataset_from_jsonl(data_path: str) -> Dataset:
-    data = []
-    with open(data_path, "r", encoding="utf-8") as f:
-        for line in f:
+def safe_conversations_generator(file_path):
+    """
+    Generator that:
+    1. Extracts the 'conversations' field.
+    2. Preserves all original fields within each message.
+    3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility).
+    """
+    with open(file_path, "r", encoding="utf-8") as f:
+        for i, line in enumerate(f):
             line = line.strip()
-            if line:
-                data.append(json.loads(line))
-    dataset = Dataset.from_list(data)
-    return dataset
+            if not line:
+                continue
+            try:
+                row = json.loads(line)
+                raw_convs = row.get("conversations", [])
+
+                # 1. Ensure 'conversations' is a list
+                if not isinstance(raw_convs, list):
+                    # If it's None or some unexpected type, treat as empty or skip
+                    if raw_convs is None:
+                        raw_convs = []
+                    else:
+                        # Edge case: 'conversations' is a plain string or non-iterable—skip this line
+                        print(
+                            f"⚠️ Line {i}: 'conversations' is not a list. Please check!"
+                        )
+                        continue
+
+                cleaned_convs = []
+                for idx, msg in enumerate(raw_convs):
+                    # 2. Ensure each item in the list is a dictionary
+                    if not isinstance(msg, dict):
+                        # Skip if an element is not a dict (e.g., malformed like ["user", "hi"])
+                        continue
+
+                    # 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.)
+                    new_msg = {}
+                    for k, v in msg.items():
+                        # If the value is a list or dict, serialize it to a JSON string
+                        # This ensures Arrow treats the column as string type instead of list/struct
+                        if isinstance(v, (list, dict)):
+                            new_msg[k] = json.dumps(v, ensure_ascii=False)
+                        else:
+                            # Keep primitive types (str, int, float, bool, None) unchanged
+                            new_msg[k] = v
+
+                    cleaned_convs.append(new_msg)
+
+                # Yield only the processed 'conversations'
+                yield {"conversations": cleaned_convs}
+
+            except Exception as e:
+                print(f"⚠️ Skipping line {i}: {e}")
+                continue
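The generator above defines the data contract used by the rest of the series: any message field that is a list or dict reaches Arrow as a JSON string. A minimal, self-contained sketch of how it is meant to be consumed follows; the sample record, temporary file, and printed field are illustrative assumptions, not part of the patch.

    import json
    import tempfile

    from datasets import Dataset

    from specforge.utils import safe_conversations_generator

    sample = {
        "conversations": [
            {"role": "user", "content": "List the files."},
            {
                "role": "assistant",
                "content": "",
                # list-valued field: serialized to a JSON string so Arrow sees a
                # uniform string column instead of a mixed list/struct type
                "tool_calls": [{"name": "ls", "arguments": {"path": "."}}],
            },
        ]
    }

    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        jsonl_path = f.name

    # Mirrors the calls added to scripts/train_eagle3.py and prepare_hidden_states.py.
    dataset = Dataset.from_generator(
        generator=safe_conversations_generator,
        gen_kwargs={"file_path": jsonl_path},
    )
    print(dataset[0]["conversations"][1]["tool_calls"])  # a JSON string, not a list

Because `tool_calls` now arrives as a string, the parser has to re-parse it back into a structure, which is exactly what the next two patches adjust.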
From 61cbcc19cb13aeb4eca4ea7ecc7380c127528731 Mon Sep 17 00:00:00 2001
From: canghua
Date: Thu, 15 Jan 2026 22:25:55 +0800
Subject: [PATCH 03/10] polish code

---
 specforge/data/parse.py | 9 +++++++--
 specforge/utils.py      | 8 ++++----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index 3125412a1..19d8d25af 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -92,8 +92,13 @@ def parse(
                     f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                 )
                 break
-            if sentence["tool_calls"] is not None:
-                sentence["tool_calls"] = json.loads(sentence["tool_calls"])
+            tool_calls = sentence.get("tool_calls")
+            if isinstance(tool_calls, str):
+                try:
+                    sentence["tool_calls"] = json.loads(tool_calls)
+                except json.JSONDecodeError:
+                    warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}")
+                    break
             messages.append(sentence)
 
         try:
diff --git a/specforge/utils.py b/specforge/utils.py
index 859d064cc..59724a824 100644
--- a/specforge/utils.py
+++ b/specforge/utils.py
@@ -326,13 +326,13 @@ def safe_conversations_generator(file_path):
                         raw_convs = []
                     else:
                         # Edge case: 'conversations' is a plain string or non-iterable—skip this line
-                        print(
-                            f"⚠️ Line {i}: 'conversations' is not a list. Please check!"
+                        logger.warning(
+                            f"Line {i + 1}: 'conversations' is not a list. Please check!"
                         )
                         continue
 
                 cleaned_convs = []
-                for idx, msg in enumerate(raw_convs):
+                for msg in raw_convs:
                     # 2. Ensure each item in the list is a dictionary
                     if not isinstance(msg, dict):
                         # Skip if an element is not a dict (e.g., malformed like ["user", "hi"])
@@ -355,5 +355,5 @@ def safe_conversations_generator(file_path):
                 yield {"conversations": cleaned_convs}
 
             except Exception as e:
-                print(f"⚠️ Skipping line {i}: {e}")
+                logger.warning(f"Skipping line {i + 1}: {e}")
                 continue

From 9c9407961dbf6ffcde7b21d6f70e457ddb1e89b3 Mon Sep 17 00:00:00 2001
From: canghua
Date: Fri, 16 Jan 2026 12:47:00 +0800
Subject: [PATCH 04/10] add max_length truncation to tokenizer encode calls

---
 specforge/data/parse.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index 19d8d25af..d9d75b10e 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -103,7 +103,7 @@ def parse(
 
         try:
             conversation = self.apply_chat_template(
-                messages, max_length=max_length, **kwargs
+                messages, **kwargs
             )
         except (ValueError, TypeError):
             # Fallback rendering for tokenizers without built-in chat_template
@@ -155,11 +155,11 @@ def parse(
         # --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length ---
         # Encode the text "assistant start", the length of which is the position of the starting token.
         prefix_ids = self.tokenizer.encode(
-            conversation[:content_start_char], add_special_tokens=False
+            conversation[:content_start_char], add_special_tokens=False, truncation=True, max_length=max_length,
        )
         # Encodes the text "assistant end", the length of which is the position of the end token.
         full_ids = self.tokenizer.encode(
-            conversation[:content_end_char], add_special_tokens=False
+            conversation[:content_end_char], add_special_tokens=False, truncation=True, max_length=max_length,
         )
 
         start_token_idx = len(prefix_ids)
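The patch above bounds both prefix encodes with the same max_length, so the recovered token indices stay consistent with the truncated training sequence. Below is a rough, self-contained illustration of that prefix-length trick; the checkpoint name and toy conversation string are stand-ins, and this is not SpecForge's actual parser.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here

    conversation = "system: be brief\nuser: hi\nassistant: hello there"
    content_start_char = conversation.index("hello")
    content_end_char = len(conversation)
    max_length = 32

    # Token index of a character offset = length of the encoded character prefix.
    prefix_ids = tokenizer.encode(
        conversation[:content_start_char],
        add_special_tokens=False,
        truncation=True,
        max_length=max_length,
    )
    full_ids = tokenizer.encode(
        conversation[:content_end_char],
        add_special_tokens=False,
        truncation=True,
        max_length=max_length,
    )
    start_token_idx, end_token_idx = len(prefix_ids), len(full_ids)
    print(start_token_idx, end_token_idx)  # span of assistant tokens within the truncated sequence

Passing truncation and max_length to both calls keeps the two lengths on the same scale, which is what makes the span arithmetic safe for very long conversations.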
From 83b83d82088ecf8baacd52214ddd932447ab44bd Mon Sep 17 00:00:00 2001
From: canghua
Date: Fri, 16 Jan 2026 12:47:36 +0800
Subject: [PATCH 05/10] polish code

---
 specforge/data/parse.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index d9d75b10e..39ec6680e 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -102,9 +102,7 @@ def parse(
             messages.append(sentence)
 
         try:
-            conversation = self.apply_chat_template(
-                messages, **kwargs
-            )
+            conversation = self.apply_chat_template(messages, **kwargs)
         except (ValueError, TypeError):
             # Fallback rendering for tokenizers without built-in chat_template
             warnings.warn(
@@ -155,11 +153,17 @@ def parse(
         # --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length ---
         # Encode the text "assistant start", the length of which is the position of the starting token.
         prefix_ids = self.tokenizer.encode(
-            conversation[:content_start_char], add_special_tokens=False, truncation=True, max_length=max_length,
+            conversation[:content_start_char],
+            add_special_tokens=False,
+            truncation=True,
+            max_length=max_length,
         )
         # Encodes the text "assistant end", the length of which is the position of the end token.
         full_ids = self.tokenizer.encode(
-            conversation[:content_end_char], add_special_tokens=False, truncation=True, max_length=max_length,
+            conversation[:content_end_char],
+            add_special_tokens=False,
+            truncation=True,
+            max_length=max_length,
         )
 
         start_token_idx = len(prefix_ids)

From 79324f8582142da98659d1ad6d8da2520fdb620f Mon Sep 17 00:00:00 2001
From: canghua
Date: Fri, 16 Jan 2026 12:49:31 +0800
Subject: [PATCH 06/10] add repo-wiki template

---
 examples/repo-wiki.sh      | 92 ++++++++++++++++++++++++++++++++++++++
 specforge/data/template.py | 11 +++++
 2 files changed, 103 insertions(+)
 create mode 100755 examples/repo-wiki.sh

diff --git a/examples/repo-wiki.sh b/examples/repo-wiki.sh
new file mode 100755
index 000000000..3d48afe6b
--- /dev/null
+++ b/examples/repo-wiki.sh
@@ -0,0 +1,92 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
+
+# support tp4/tp8 train eagle3 for Qwen3-30B-A3B
+
+# export TOKENIZERS_PARALLELISM=false
+
+NUM_GPUS=8
+TP_SIZE=2
+
+TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
+TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl
+
+
+
+# # Prepare hidden states
+# export TORCH_NCCL_TIMEOUT_SEC=1800
+# torchrun \
+#     --standalone \
+#     --nproc_per_node $NUM_GPUS \
+#     scripts/prepare_hidden_states.py \
+#     --target-model-path $TARGET_MODEL_PATH \
+#     --enable-aux-hidden-states \
+#     --data-path $TRAIN_DATA_PATH \
+#     --chat-template repo-wiki \
+#     --tp-size $TP_SIZE \
+#     --batch-size 4 \
+#     --max-length 65536 \
+#     --output-path $ROOT_DIR/outputs/repo-wiki/train_hidden_states \
+#     --sglang-mem-fraction-static 0.8
+
+
+
+# offline training
+BUILD_DATASET_NUM_PROC=1
+
+LOR_INTERNAL=200
+SAVE_INTERNAL=10
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3.py \
+    --target-model-path $TARGET_MODEL_PATH \
+    --train-hidden-states-path $ROOT_DIR/outputs/repo-wiki/train_hidden_states \
+    --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
+    --train-data-path $TRAIN_DATA_PATH \
+    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
+    --output-dir $ROOT_DIR/outputs/repo-wiki \
+    --num-epochs 10 \
+    --batch-size 1 \
+    --learning-rate 1e-4 \
+    --max-length 65536 \
+    --chat-template repo-wiki \
+    --cache-dir $ROOT_DIR/cache \
+    --embedding-key model.embed_tokens.weight \
+    --tp-size 1 \
+    --report-to tensorboard \
+    --save-interval $LOR_INTERNAL \
+    --log-interval $SAVE_INTERNAL \
+    --sp-ring-size 2 \
+    --sp-ulysses-size 4 \
+    --attention-backend usp
+
+
+# online training
+# torchrun \
+#     --standalone \
+#     --nproc_per_node $NUM_GPUS \
+#     $ROOT_DIR/scripts/train_eagle3.py \
+#     --target-model-path $TARGET_MODEL_PATH \
+#     --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
+#     --train-data-path $TRAIN_DATA_PATH \
+#     --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
+#     --output-dir $ROOT_DIR/outputs/repo-wiki \
+#     --num-epochs 10 \
+#     --batch-size 1 \
+#     --learning-rate 1e-4 \
+#     --max-length 32768 \
+#     --chat-template repo-wiki \
+#     --cache-dir $ROOT_DIR/cache \
+#     --embedding-key model.embed_tokens.weight \
+#     --tp-size 1 \
+#     --report-to tensorboard \
+#     --save-interval $LOR_INTERNAL \
+#     --log-interval $SAVE_INTERNAL \
+#     --sp-ring-size 2 \
+#     --sp-ulysses-size 4 \
+#     --attention-backend usp
\ No newline at end of file
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 4ca009485..ff786ee9c 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -276,3 +276,14 @@ def get_all_template_names(self) -> List[str]:
         enable_thinking=True,
     ),
 )
+
+
+TEMPLATE_REGISTRY.register(
+    name="repo-wiki",
+    template=ChatTemplate(
+        assistant_header="<|im_start|>assistant\n",
+        user_header="<|im_start|>user\n",
+        system_prompt="",
+        end_of_turn_token="<|im_end|>\n",
+    ),
+)
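For reference, the four fields registered above are enough to lay out a Qwen-style conversation. The rendering helper below is a hypothetical sketch for illustration only; it is not the renderer SpecForge uses internally, and it ignores system and tool turns since the template's system prompt is empty.

    ASSISTANT_HEADER = "<|im_start|>assistant\n"
    USER_HEADER = "<|im_start|>user\n"
    END_OF_TURN = "<|im_end|>\n"

    def render(messages):
        # Concatenate header + content + end-of-turn token for each message.
        out = []
        for msg in messages:
            header = USER_HEADER if msg["role"] == "user" else ASSISTANT_HEADER
            out.append(header + msg["content"] + END_OF_TURN)
        return "".join(out)

    print(render([
        {"role": "user", "content": "Summarize this repository."},
        {"role": "assistant", "content": "It trains EAGLE-3 draft models."},
    ]))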
From d00ee3993afa0778be006245ed4535011d9f4463 Mon Sep 17 00:00:00 2001
From: canghua
Date: Wed, 21 Jan 2026 16:16:24 +0800
Subject: [PATCH 07/10] optimize GPU memory usage for long-context training

---
 scripts/train_eagle3.py                  | 12 +++++-----
 specforge/core/eagle3.py                 | 11 ++++++++---
 specforge/modeling/target/target_head.py |  1 -
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py
index 279bf91d7..685e4c637 100644
--- a/scripts/train_eagle3.py
+++ b/scripts/train_eagle3.py
@@ -546,6 +546,7 @@ def run_forward(
     target_model: Optional[Eagle3TargetModel] = None,
     is_online: bool = True,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    print(data["input_ids"].shape)
     if args.is_vlm:
         plosses, _, acces = eagle3_model(
             input_ids=data["input_ids"].cuda(),
@@ -570,15 +571,16 @@ def run_forward(
             hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
     else:
         # we generate the logits using the hidden states loaded from disk
-        input_ids = data["input_ids"].cuda()
         attention_mask = data["attention_mask"].cuda()
-        loss_mask = data["loss_mask"].cuda()
         hidden_states = data["hidden_state"].cuda()
-        target = target_model(data["target"].cuda())
         input_ids, target, loss_mask = target_model.preprocess(
-            input_ids, target, loss_mask
+            data["input_ids"], data["target"], data["loss_mask"]
         )
-
+        input_ids = input_ids.cuda()
+        target = target_model(
+            target.cuda()
+        )  # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
+        loss_mask = loss_mask.cuda()
         plosses, _, acces = eagle3_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 34c6d6b5e..0f06951f1 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -118,6 +118,7 @@ def forward(
             length=self.length,
         )
         del target
+        torch.cuda.empty_cache()
 
         # basic info
         batch_size, seq_length, _ = hidden_states.shape
@@ -150,7 +151,7 @@ def forward(
                 dtype=torch.bool,
                 device=hidden_states.device,
             )
-        if self.attention_backend in ("sdpa", "usp"):
+        if self.attention_backend == "sdpa":
             attention_mask = self.draft_model.prepare_decoder_attention_mask(
                 attention_mask=attention_mask,
                 hidden_states=hidden_states,
@@ -206,8 +207,12 @@ def forward(
             hidden_states = hidden_states_out
 
             # Step 5.4: get logits
-            logits = self.draft_model.compute_logits(hidden_states)
-            logits = gather_outputs_and_unpad(logits, gather_dim=1)
+            logits_ = self.draft_model.compute_logits(hidden_states)
+            # from .forkedpdb import ForkedPdb
+            # ForkedPdb().set_trace()
+            logits = gather_outputs_and_unpad(logits_, gather_dim=1)
+            del logits_
+            torch.cuda.empty_cache()
             # Step 5.5: record metrics first as we in-place modify logits
             with torch.no_grad():
                 acces.append(
diff --git a/specforge/modeling/target/target_head.py b/specforge/modeling/target/target_head.py
index 3918d6d69..7231117ce 100644
--- a/specforge/modeling/target/target_head.py
+++ b/specforge/modeling/target/target_head.py
@@ -89,5 +89,4 @@ def preprocess(self, input_ids, target, loss_mask):
         target = padding(target, left=False)
         input_ids = padding(input_ids, left=False)
         loss_mask = loss_mask[..., None]
-        loss_mask = loss_mask.to(target.device)
         return input_ids, target, loss_mask

From 54124913fa12fa08429887aad3c28e95b6915c47 Mon Sep 17 00:00:00 2001
From: canghua
Date: Wed, 21 Jan 2026 21:34:40 +0800
Subject: [PATCH 08/10] remove unused example file

---
 examples/repo-wiki.sh    | 92 ----------------------------------------
 specforge/core/eagle3.py |  2 --
 2 files changed, 94 deletions(-)
 delete mode 100755 examples/repo-wiki.sh

diff --git a/examples/repo-wiki.sh b/examples/repo-wiki.sh
deleted file mode 100755
index 3d48afe6b..000000000
--- a/examples/repo-wiki.sh
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/bin/bash
-
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-ROOT_DIR=$(dirname $SCRIPT_DIR)
-export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
-
-# support tp4/tp8 train eagle3 for Qwen3-30B-A3B
-
-# export TOKENIZERS_PARALLELISM=false
-
-NUM_GPUS=8
-TP_SIZE=2
-
-TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
-TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl
-
-
-
-# # Prepare hidden states
-# export TORCH_NCCL_TIMEOUT_SEC=1800
-# torchrun \
-#     --standalone \
-#     --nproc_per_node $NUM_GPUS \
-#     scripts/prepare_hidden_states.py \
-#     --target-model-path $TARGET_MODEL_PATH \
-#     --enable-aux-hidden-states \
-#     --data-path $TRAIN_DATA_PATH \
-#     --chat-template repo-wiki \
-#     --tp-size $TP_SIZE \
-#     --batch-size 4 \
-#     --max-length 65536 \
-#     --output-path $ROOT_DIR/outputs/repo-wiki/train_hidden_states \
-#     --sglang-mem-fraction-static 0.8
-
-
-
-# offline training
-BUILD_DATASET_NUM_PROC=1
-
-LOR_INTERNAL=200
-SAVE_INTERNAL=10
-
-torchrun \
-    --standalone \
-    --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3.py \
-    --target-model-path $TARGET_MODEL_PATH \
-    --train-hidden-states-path $ROOT_DIR/outputs/repo-wiki/train_hidden_states \
-    --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
-    --train-data-path $TRAIN_DATA_PATH \
-    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
-    --output-dir $ROOT_DIR/outputs/repo-wiki \
-    --num-epochs 10 \
-    --batch-size 1 \
-    --learning-rate 1e-4 \
-    --max-length 65536 \
-    --chat-template repo-wiki \
-    --cache-dir $ROOT_DIR/cache \
-    --embedding-key model.embed_tokens.weight \
-    --tp-size 1 \
-    --report-to tensorboard \
-    --save-interval $LOR_INTERNAL \
-    --log-interval $SAVE_INTERNAL \
-    --sp-ring-size 2 \
-    --sp-ulysses-size 4 \
-    --attention-backend usp
-
-
-# online training
-# torchrun \
-#     --standalone \
-#     --nproc_per_node $NUM_GPUS \
-#     $ROOT_DIR/scripts/train_eagle3.py \
-#     --target-model-path $TARGET_MODEL_PATH \
-#     --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
-#     --train-data-path $TRAIN_DATA_PATH \
-#     --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
-#     --output-dir $ROOT_DIR/outputs/repo-wiki \
-#     --num-epochs 10 \
-#     --batch-size 1 \
-#     --learning-rate 1e-4 \
-#     --max-length 32768 \
-#     --chat-template repo-wiki \
-#     --cache-dir $ROOT_DIR/cache \
-#     --embedding-key model.embed_tokens.weight \
-#     --tp-size 1 \
-#     --report-to tensorboard \
-#     --save-interval $LOR_INTERNAL \
-#     --log-interval $SAVE_INTERNAL \
-#     --sp-ring-size 2 \
-#     --sp-ulysses-size 4 \
-#     --attention-backend usp
\ No newline at end of file
diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 5a7d5fda0..aad37e4f9 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -220,8 +220,6 @@ def forward(
 
             # Step 5.4: get logits
             logits_ = self.draft_model.compute_logits(hidden_states)
-            # from .forkedpdb import ForkedPdb
-            # ForkedPdb().set_trace()
             logits = gather_outputs_and_unpad(logits_, gather_dim=1)
             del logits_
             torch.cuda.empty_cache()
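The ordering change in patch 07 is the main memory saver in this pair of patches: the [batch, seqlen, vocab_size] target is shifted and masked while it still lives in CPU memory, and only the final tensors are copied to the GPU. A standalone sketch of that pattern with toy shapes; the preprocessing step below is an illustrative stand-in, not SpecForge's `padding` or `preprocess` helpers.

    import torch

    def preprocess_on_cpu(input_ids, target, loss_mask):
        # Illustrative stand-in for target_model.preprocess: shift targets by one
        # position for next-token prediction and add a trailing mask dimension.
        target = torch.roll(target, shifts=-1, dims=1)
        loss_mask = loss_mask[..., None]
        return input_ids, target, loss_mask

    batch, seqlen, vocab = 1, 512, 32_000
    input_ids = torch.randint(0, vocab, (batch, seqlen))
    target = torch.randn(batch, seqlen, vocab)  # large tensor, kept on CPU for now
    loss_mask = torch.ones(batch, seqlen)

    input_ids, target, loss_mask = preprocess_on_cpu(input_ids, target, loss_mask)
    if torch.cuda.is_available():
        # Only one copy of each tensor ever reaches the GPU.
        input_ids = input_ids.cuda()
        target = target.cuda()
        loss_mask = loss_mask.cuda()

Doing the cheap bookkeeping first means the pre-shift copy of the target never has to coexist with the shifted copy in GPU memory, which matters at 32k to 64k sequence lengths.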
From 79975fc5437d608feb93339c5571d2bc38f06e7f Mon Sep 17 00:00:00 2001
From: canghua
Date: Wed, 21 Jan 2026 23:35:39 +0800
Subject: [PATCH 09/10] use activation checkpointing to compute loss

---
 specforge/core/eagle3.py | 49 ++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 17 deletions(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index aad37e4f9..abf43527f 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -25,6 +25,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from transformers.cache_utils import DynamicCache
 from yunchang import EXTRACT_FUNC_DICT
 
@@ -176,6 +177,24 @@ def forward(
                 past_key_values_length=past_key_values_length,
             )
 
+        def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask):
+            # 1. Compute the logits (the part that consumes the most VRAM).
+            logits_ = self.draft_model.compute_logits(hs)
+            logits = gather_outputs_and_unpad(logits_, gather_dim=1)
+
+            # 2. Compute the loss.
+            loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask)
+
+            # 3. Compute the accuracy.
+            with torch.no_grad():
+                acc_val = _compute_metric_acc(
+                    logits=logits,
+                    target_p=tgt_p,
+                    position_mask=pos_mask,
+                    loss_mask=l_mask,
+                )
+            return loss_val, acc_val
+
         # Step 5: run TTT
         plosses = []
         vlosses = []
@@ -218,26 +237,22 @@ def forward(
             # update hidden states for next step
             hidden_states = hidden_states_out
 
-            # Step 5.4: get logits
-            logits_ = self.draft_model.compute_logits(hidden_states)
-            logits = gather_outputs_and_unpad(logits_, gather_dim=1)
-            del logits_
-            torch.cuda.empty_cache()
-            # Step 5.5: record metrics first as we in-place modify logits
-            with torch.no_grad():
-                acces.append(
-                    _compute_metric_acc(
-                        logits=logits,
-                        target_p=target_p,
-                        position_mask=position_mask,
-                        loss_mask=loss_mask,
-                    )
+            if hidden_states.requires_grad:
+                loss, acc = checkpoint(
+                    compute_loss_and_acc_checkpointed,
+                    hidden_states,
+                    target_p,
+                    position_mask,
+                    loss_mask,
+                    use_reentrant=False,
+                )
+            else:
+                loss, acc = compute_loss_and_acc_checkpointed(
+                    hidden_states, target_p, position_mask, loss_mask
                 )
-            # Step 5.6: calculate loss, in-place modifies logits!
-            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
             plosses.append(loss)
-
+            acces.append(acc)
             if not is_last:
                 # Step 5.7: we need to update the loss mask
                 global_input_ids = padding(global_input_ids, left=False)

From 5adf8402bf85b35de183f5d48942970bc26c34f0 Mon Sep 17 00:00:00 2001
From: jiapingW <1969554248@qq.com>
Date: Thu, 22 Jan 2026 16:46:21 +0800
Subject: [PATCH 10/10] cache the generated dataset locally instead of in the default cache

---
 scripts/prepare_hidden_states.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index 4df35011a..d201ca479 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -473,7 +473,6 @@ def generate(
         filtered_batch_gpu = {
             k: v.cuda(non_blocking=True) for k, v in filtered_batch.items()
         }
-
         _, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend(
             **filtered_batch_gpu,
             return_last_hidden_states=True,
@@ -579,11 +578,16 @@ def main():
         args.data_path
     ), f"Dataset path {args.data_path} does not exist"
     dataset = Dataset.from_generator(
-        generator=safe_conversations_generator, gen_kwargs={"file_path": args.data_path}
+        generator=safe_conversations_generator,
+        gen_kwargs={"file_path": args.data_path},
+        cache_dir=os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "cache",
+            "hf_dataset",
+        ),
     )
     if args.num_samples is not None:
         dataset = dataset.select(range(args.num_samples))
-
     # Tokenizer and cache key
     tokenizer = AutoTokenizer.from_pretrained(
         args.target_model_path, trust_remote_code=True
@@ -649,7 +653,7 @@ def main():
     # Pass configurable arguments from args if needed
     with HiddenStatesGenerator(
         target_model,
-        args.enable_aux_hidden_states,
+        enable_aux_hidden_states=args.enable_aux_hidden_states,
         num_io_threads=args.num_io_threads,
         io_queue_size=args.io_queue_size,
         file_group_size=args.file_group_size,
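To close, here is patch 09's activation-checkpointing trick in isolation: the logits projection and the loss are recomputed during the backward pass instead of keeping the full [batch, seqlen, vocab_size] logits alive across the whole TTT step. The module, shapes, and soft-target loss below are toy stand-ins for the draft model's head and LogSoftmaxLoss, not SpecForge's implementation.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    hidden_size, vocab_size = 256, 32_000
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def loss_from_hidden(hidden_states, target_p):
        # The logits tensor is the largest activation; under checkpointing it is
        # discarded after the forward pass and rebuilt during backward.
        logits = lm_head(hidden_states)
        logp = torch.log_softmax(logits, dim=-1)
        return -(target_p * logp).sum(dim=-1).mean()

    hidden_states = torch.randn(2, 128, hidden_size, requires_grad=True)
    target_p = torch.softmax(torch.randn(2, 128, vocab_size), dim=-1)

    if hidden_states.requires_grad:
        loss = checkpoint(loss_from_hidden, hidden_states, target_p, use_reentrant=False)
    else:
        loss = loss_from_hidden(hidden_states, target_p)
    loss.backward()

The non-reentrant mode (use_reentrant=False) is what lets gradients flow to both the checkpointed inputs and the module parameters, which is why the patch guards the checkpointed path behind hidden_states.requires_grad and falls back to a plain call otherwise.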