From b1e5ffbbb12b07f600d62cd46b7d5c4e48705e2f Mon Sep 17 00:00:00 2001
From: xiaonengmiao
Date: Tue, 20 Jan 2026 16:31:12 +0800
Subject: [PATCH] fix: typo and dict access via dot notation bug

---
 eagle/traineagle3/cnets.py | 4 ++--
 eagle/traineagle3/main.py  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/eagle/traineagle3/cnets.py b/eagle/traineagle3/cnets.py
index 4c623002..d379d1af 100644
--- a/eagle/traineagle3/cnets.py
+++ b/eagle/traineagle3/cnets.py
@@ -489,7 +489,7 @@ def __init__(self, config, ds_config, training_config, load_head=False, load_emb
         else:
             dschf = None
         self.midlayer = LlamaDecoderLayeremb(config)
-        self.gradient_checkpointing = self.train_config.gradient_checkpointing
+        self.gradient_checkpointing = self.train_config["gradient_checkpointing"]
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
@@ -588,7 +588,7 @@ def preprocess_function(examples):
             # When construct draft model vocab,
             # filter out samples which is longer than max_len,
             # instead of truncating them.
-            if len(input_ids) > self.train_config.max_len:
+            if len(input_ids) > self.train_config["max_len"]:
                 continue
             loss_mask = torch.ones_like(input_ids)
             # print(i)
diff --git a/eagle/traineagle3/main.py b/eagle/traineagle3/main.py
index 37012a16..a6efa46b 100644
--- a/eagle/traineagle3/main.py
+++ b/eagle/traineagle3/main.py
@@ -23,7 +23,7 @@
     "num_workers": 2,
     "max_len": 2048,
     "config_path": "config.json",
-    "gradient_checkpoint": True
+    "gradient_checkpointing": True
 }
 
 from safetensors import safe_open
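
Below is a minimal sketch of the failure mode this patch fixes, assuming only what the diff shows: train_config is a plain Python dict (the defaults block in main.py), so dot notation resolves attributes rather than keys, and the key itself was misspelled ("gradient_checkpoint" vs. "gradient_checkpointing"). The sketch is illustrative and not part of the patch.

    # Minimal sketch, not part of the patch. Assumes `train_config` is the
    # plain dict defined in main.py.
    train_config = {
        "num_workers": 2,
        "max_len": 2048,
        "config_path": "config.json",
        "gradient_checkpoint": True,  # pre-patch key name (the typo)
    }

    # Dot notation looks up an attribute, not a dict key, so this raises:
    try:
        train_config.gradient_checkpointing
    except AttributeError as err:
        print(err)  # 'dict' object has no attribute 'gradient_checkpointing'

    # The patch fixes both sides: main.py renames the key, and the reads in
    # cnets.py switch from dot notation to bracket access.
    train_config["gradient_checkpointing"] = train_config.pop("gradient_checkpoint")
    assert train_config["gradient_checkpointing"] is True

Both halves are needed: switching to bracket access alone would still fail with a KeyError until the key name in main.py matches the one read in cnets.py.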