From 5b509e87c2faa2b482f3fea32d598df914925874 Mon Sep 17 00:00:00 2001
From: Mohammed-Faizzzz <110959467+Mohammed-Faizzzz@users.noreply.github.com>
Date: Mon, 23 Jun 2025 19:33:29 +0100
Subject: [PATCH 1/3] Add robust error handling to LoRA validation logic

This commit improves fault tolerance and clarity in the LoRA validation
module by adding structured try-except blocks around critical operations.
The goal is to fail early, clearly, and recoverably when encountering
invalid configuration or runtime issues during model evaluation.

- Validates the LoRA config before TrainingArguments instantiation
- Catches and logs errors when loading the tokenizer, dataset, model, or trainer
- Falls back to placeholder metrics when byte statistics are unavailable
  (total_bytes == 0) instead of crashing in the BPC computation
- Adds informative logging for debugging and traceability

This improves resilience against malformed adapter configs, missing
files, and resource-related evaluation crashes.
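For reviewers, a minimal sketch of the intended failure contract (illustrative
only: `module` and `data` stand for a constructed LoRAValidationModule and its
input data, and it assumes `InvalidConfigValueException` is the
RecoverableException subclass referenced at the top of this module):

    from loguru import logger
    from validator.modules.lora import InvalidConfigValueException

    try:
        metrics = module.validate(data)
    except InvalidConfigValueException as exc:
        # Recoverable: the assignment is retried after the user fixes the
        # problem and restarts the process.
        logger.warning(f"Config problem, not a hard failure: {exc}")
    except RuntimeError as exc:
        # A crashed evaluation is surfaced as a hard failure.
        logger.error(f"Evaluation crashed: {exc}")
        raise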
---
 validator/modules/lora/__init__.py | 237 +++++++++++++++++++++++------
 1 file changed, 190 insertions(+), 47 deletions(-)

diff --git a/validator/modules/lora/__init__.py b/validator/modules/lora/__init__.py
index e4ae232..5316cb0 100644
--- a/validator/modules/lora/__init__.py
+++ b/validator/modules/lora/__init__.py
@@ -38,6 +38,18 @@ class LoRAValidationModule(BaseValidationModule):
     def __init__(self, config: LoRAConfig, **kwargs):
         # Store the config for later use
        self.config = config
+
+    def validate_config(self):
+        """Validate logical correctness of LoRA config values."""
+        import os
+
+        if self.config.per_device_eval_batch_size <= 0:
+            raise InvalidConfigValueException("`per_device_eval_batch_size` must be > 0.")
+        if not self.config.output_dir.strip():
+            raise InvalidConfigValueException("`output_dir` must not be empty or blank.")
+        if not os.path.isdir(self.config.output_dir):
+            raise InvalidConfigValueException(f"`output_dir` '{self.config.output_dir}' does not exist or is not a directory.")
+
     def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
         """Run the validation procedure.
 
@@ -46,6 +58,9 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
         pared-down so that it only performs the forward pass and metric
         computation needed for the validator.  All networking / ledger
         interaction has been removed.
+
+        For error handling, we raise `InvalidConfigValueException` instead of simply returning as
+        was previously done. Any errors will be bubbled up to the validation runner.
         """
         import json
         import math
@@ -57,6 +72,7 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
         from loguru import logger
         from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
         from peft import PeftModel
+        from json import JSONDecodeError
 
         from .core import (
             SFTDataCollator,
@@ -73,15 +89,30 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
         # ------------------------------------------------------------------
         # Build HF `TrainingArguments` object from config
         # ------------------------------------------------------------------
-        val_args = TrainingArguments(
-            output_dir=self.config.output_dir,
-            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
-            fp16=self.config.fp16,
-            remove_unused_columns=self.config.remove_unused_columns,
-            do_train=False,
-            do_eval=True,
-            report_to=[],  # silence wandb etc.
-        )
+        # Validate config values before creating TrainingArguments (ideally this would live in __init__)
+        logger.info("Validating config...")
+        try:
+            self.validate_config()
+        except InvalidConfigValueException as e:
+            logger.error(f"Invalid config: {e}")
+            raise InvalidConfigValueException(f"Config validation failed: {e}")
+        logger.info("Config validation passed.")
+
+        # Enclose in try-except in case the config values are rejected by TrainingArguments itself
+        try:
+            logger.info("Creating TrainingArguments...")
+            val_args = TrainingArguments(
+                output_dir=self.config.output_dir,
+                per_device_eval_batch_size=self.config.per_device_eval_batch_size,
+                fp16=self.config.fp16,
+                remove_unused_columns=self.config.remove_unused_columns,
+                do_train=False,
+                do_eval=True,
+                report_to=[],  # silence wandb etc.
+            )
+        except Exception as e:
+            logger.error(f"Failed to create TrainingArguments: {e}")
+            raise InvalidConfigValueException(f"TrainingArguments creation failed: {e}")  # Might be good to introduce more specific exceptions here
 
         model_repo = data.hg_repo_id
         revision = data.revision or "main"
@@ -90,42 +121,123 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
         # ------------------------------------------------------------------
         # Determine if the repo contains LoRA adapter weights or a full model
         # ------------------------------------------------------------------
+        # More robust error handling around LoRA detection
         is_lora = download_lora_config(model_repo, revision)
+
         if is_lora:
             adapter_cfg_path = Path("lora/adapter_config.json")
-            try:
-                with open(adapter_cfg_path, "r") as f:
-                    adapter_cfg = json.load(f)
-                base_model_path = adapter_cfg.get("base_model_name_or_path")
-                if base_model_path:
+            if not adapter_cfg_path.exists():
+                # is_lora is True, but adapter_config.json does not exist
+                logger.error(
+                    f"Model {model_repo} is identified as LoRA, but its adapter_config.json was not downloaded or found at {adapter_cfg_path}. "
+                    f"This could be due to an issue with 'download_lora_config' or the repository structure for the LoRA model. "
+                )
+                raise InvalidConfigValueException(f"adapter_config.json not found for LoRA model {model_repo}. ")
+            else:
+                try:
+                    with open(adapter_cfg_path, "r") as f:
+                        adapter_cfg = json.load(f)
+
+                    base_model_path = adapter_cfg.get("base_model_name_or_path")
+                    if not base_model_path or not base_model_path.strip():
+                        logger.error(
+                            f"LoRA model {model_repo} does not specify 'base_model_name_or_path' "
+                        )
+                        return # exit function early
+
                     tokenizer_repo = base_model_path
-            except Exception as e:
-                logger.warning(f"Failed to parse adapter_config.json: {e}")
-
+
+                except (FileNotFoundError, JSONDecodeError) as e:  # the missing-file case is already handled above
+                    logger.error(f"Failed to read adapter_config.json: {e}")
+                    raise InvalidConfigValueException(f"adapter_config.json parsing failed: {e}")
+                except InvalidConfigValueException as e:
+                    logger.error(str(e))
+                    raise
+                except Exception as e:
+                    logger.error(f"Unexpected error reading adapter_config.json: {e}")
+                    raise
+        else:
+            logger.info(
+                f"Model {model_repo} is not identified as a LoRA model. "
+                f"Using its own path for tokenizer: {model_repo}."
+            )
 
         # ------------------------------------------------------------------
         # Tokeniser
         # ------------------------------------------------------------------
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo, use_fast=True)
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        # Enclose the tokenizer loading in a try-except block to handle potential errors
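+        # Notes on the helper below: gemma checkpoints need their chat-turn
+        # markers registered as additional special tokens; QWenTokenizer ships
+        # with no pad/bos/eos ids, so all three are mapped to its eod_id; any
+        # tokenizer still missing a pad token falls back to its eos token.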
+        def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:  # TODO: move this to a separate file and import it
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name_or_path,
+                use_fast=True,
+            )
+            if "gemma" in model_name_or_path.lower():
+                tokenizer.add_special_tokens(
+                    {"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}
+                )
+
+            if tokenizer.__class__.__name__ == "QWenTokenizer":
+                tokenizer.pad_token_id = tokenizer.eod_id
+                tokenizer.bos_token_id = tokenizer.eod_id
+                tokenizer.eos_token_id = tokenizer.eod_id
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+            assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
+            assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
+            logger.info(f"vocab_size of tokenizer: {tokenizer.vocab_size}")
+            return tokenizer
+
+        try:
+            logger.info("Loading tokenizer …")
+            tokenizer = load_tokenizer(tokenizer_repo)
+        except Exception as e:
+            logger.error(f"Failed to load tokenizer from {tokenizer_repo}: {e}")
+            raise InvalidConfigValueException(f"Tokenizer loading failed: {e}")
 
         # ------------------------------------------------------------------
         # Evaluation dataset
         # ------------------------------------------------------------------
+        token_byte_ratio_value = float("nan")  # pre-set so the summary below cannot hit a NameError when total_bytes == 0
         template_name = data.base_model if data.base_model in template_dict else "default"
         template = template_dict[template_name]
-        eval_dataset = UnifiedSFTDataset(
-            file=data.eval_file,
-            tokenizer=tokenizer,
-            max_seq_length=data.context_length,
-            template=template,
-        )
-
-        total_bytes, total_target_tokens = calculate_bytes_and_tokens(
-            eval_dataset, tokenizer, logger
-        )
-        token_byte_ratio_value = get_token_byte_ratio(total_target_tokens, total_bytes)
-
+
+        def load_sft_dataset(eval_file: str, max_seq_length: int, template_name: str, tokenizer: AutoTokenizer) -> UnifiedSFTDataset:
+            if template_name not in template_dict.keys():
+                raise ValueError(
+                    f"template_name doesn't exist, all template_name: {template_dict.keys()}"
+                )
+            template = template_dict[template_name]
+            logger.info("Loading data with UnifiedSFTDataset")
+            return UnifiedSFTDataset(eval_file, tokenizer, max_seq_length, template)
+
+        try:
+            # Possible errors here: data.eval_file not found, etc.
+            eval_dataset = load_sft_dataset(
+                eval_file=data.eval_file,
+                max_seq_length=data.context_length,
+                template_name=template_name,  # pass the name, not the template object, so the lookup above succeeds
+                tokenizer=tokenizer
+            )
+
+            total_bytes, total_target_tokens = calculate_bytes_and_tokens(eval_dataset, tokenizer, logger)
+            if total_bytes == 0:
+                logger.warning(
+                    "Total bytes in the evaluation dataset is 0. Cannot calculate BPC. Check dataset processing."
+                )
+            else:
+                logger.info(f"Total target tokens (T): {total_target_tokens}")
+                logger.info(f"Total target bytes (B): {total_bytes}")
+                token_byte_ratio_value = get_token_byte_ratio(
+                    total_target_tokens, total_bytes
+                )
+                logger.info(f"Token/Byte ratio (T/B): {token_byte_ratio_value:.4f}")
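+                # A T/B ratio below 0.1 means fewer than one target token per
+                # ten bytes, a pattern consistent with artificially inflated
+                # eval files; the check below only warns, it does not reject.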
+                if token_byte_ratio_value < 0.1:
+                    logger.warning(
+                        f"Token/Byte ratio ({token_byte_ratio_value:.4f}) is unusually low. Potential manipulation detected."
+                    )
+        except Exception as e:
+            logger.error(f"Failed to load evaluation dataset from {data.eval_file}: {e}")
+            raise InvalidConfigValueException(f"Evaluation dataset loading failed: {e}")
+
         # ------------------------------------------------------------------
         # Model loading helper
         # ------------------------------------------------------------------
@@ -147,7 +259,15 @@ def _load_model() -> AutoModelForCausalLM:
             return AutoModelForCausalLM.from_pretrained(model_repo, **model_kwargs)
 
         logger.info("Loading model …")
-        model = _load_model()
+        try:
+            model = _load_model()
+        except Exception as e:
+            logger.error(f"Exception occurred while loading model: {e}")
+            raise InvalidConfigValueException(f"Model loading failed: {e}")
+
+        if model is None:
+            logger.error(f"Failed to load model from {model_repo}.")
+            raise InvalidConfigValueException(f"Model loading failed: {model_repo}.")
 
         # Simple parameter count check
         total_params = sum(p.numel() for p in model.parameters())
@@ -168,25 +288,41 @@ def _load_model() -> AutoModelForCausalLM:
         # ------------------------------------------------------------------
         # Prepare trainer and run evaluation
         # ------------------------------------------------------------------
-        data_collator = SFTDataCollator(tokenizer, max_seq_length=data.context_length)
-        trainer = Trainer(
-            model=model,
-            args=val_args,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            data_collator=data_collator,
-        )
+        try:
+            data_collator = SFTDataCollator(tokenizer, max_seq_length=data.context_length)
+            trainer = Trainer(
+                model=model,
+                args=val_args,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                data_collator=data_collator,
+            )
+        except Exception as e:
+            logger.error(f"Failed to set up Trainer or data collator: {e}")
+            raise InvalidConfigValueException(f"Trainer setup failed: {e}")
 
-        logger.info("Running evaluation …")
-        eval_metrics = trainer.evaluate()
-        eval_loss = eval_metrics.get("eval_loss", float("nan"))
+        try:
+            logger.info("Running evaluation …")
+            eval_metrics = trainer.evaluate()
+            eval_loss = eval_metrics.get("eval_loss", float("nan"))
+        except Exception as e:
+            logger.error(f"Model evaluation failed: {e}")
+            raise RuntimeError(f"Evaluation failed: {e}")
 
         # ------------------------------------------------------------------
         # Compute derived metrics
         # ------------------------------------------------------------------
-        bpc_metrics = calculate_bpc_bppl_metrics(
-            eval_loss, total_target_tokens, total_bytes
-        )
+        bpc_metrics = {  # failsafe values for the scenario where total_bytes is 0
+            "loss": float("nan"),
+            "bpc": float("inf"),
+            "bppl": float("inf"),
+            "nll_token_nats_total": float("nan"),
+            "nll_token_bits_total": float("nan"),
+        }
+        if total_bytes > 0:
+            bpc_metrics = calculate_bpc_bppl_metrics(
+                eval_loss, total_target_tokens, total_bytes
+            )
 
         _log_summary_table(
             model_name_or_path=model_repo,
@@ -211,6 +347,13 @@ def _load_model() -> AutoModelForCausalLM:
         # Clean-up GPU / CPU memory – important when running many validations
         # ------------------------------------------------------------------
         gc.collect()
+        if model is not None:
+            logger.debug("Offloading model to save memory")
+            model.cpu()
+            del model
+        if eval_dataset is not None:
+            logger.debug("Offloading eval_dataset to save memory")
+            del eval_dataset
         torch.cuda.empty_cache()
         # purge temporary lora directory if it exists
         if os.path.exists("lora"):

From 44a31210bcac05866cfbdd1ed8b9891fda9ebb31 Mon Sep 17 00:00:00 2001
From: Mohammed-Faizzzz <110959467+Mohammed-Faizzzz@users.noreply.github.com>
Date: Tue, 22 Jul 2025 13:37:45 +0100
Subject: [PATCH 2/3] Add check for supported base models
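The allow-list is imported from `constant.py`. A minimal illustrative sketch
of its expected shape (the entries below are placeholders, not the project's
real list):

    # constant.py (hypothetical contents, for illustration only)
    SUPPORTED_BASE_MODELS = frozenset({
        "Qwen/Qwen2.5-7B",
        "google/gemma-2-9b",
    })

Membership is tested against `base_model_name_or_path` exactly as it appears
in adapter_config.json, so entries must match the upstream repo ids verbatim.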
---
 validator/modules/lora/__init__.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/validator/modules/lora/__init__.py b/validator/modules/lora/__init__.py
index 5316cb0..e8ba50b 100644
--- a/validator/modules/lora/__init__.py
+++ b/validator/modules/lora/__init__.py
@@ -1,5 +1,6 @@
 from validator.exceptions import RecoverableException
 from validator.modules.base import BaseValidationModule, BaseConfig, BaseInputData, BaseMetrics
+from constant import SUPPORTED_BASE_MODELS
 
 # When raised, the assignment won't be marked as failed automatically and it will be retried after the user
 # fixes the problem and restarts the process.
@@ -134,6 +135,9 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
                 )
                 raise InvalidConfigValueException(f"adapter_config.json not found for LoRA model {model_repo}. ")
             else:
+                logger.info(
+                    f"Model {model_repo} is a LoRA model. Validating its base model for the tokenizer."
+                )
                 try:
                     with open(adapter_cfg_path, "r") as f:
                         adapter_cfg = json.load(f)
@@ -144,7 +148,10 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
                             f"LoRA model {model_repo} does not specify 'base_model_name_or_path' "
                         )
                         return # exit function early
-
+                    if base_model_path not in SUPPORTED_BASE_MODELS:  # SUPPORTED_BASE_MODELS must be defined in constant.py
+                        logger.error(
+                            f"LoRA's base model '{base_model_path}' is not in SUPPORTED_BASE_MODELS. "
+                        )
                     tokenizer_repo = base_model_path
 
                 except (FileNotFoundError, JSONDecodeError) as e:  # the missing-file case is already handled above

From c81d5538783f1a2b8b802baf6bbc2007ef47b2be Mon Sep 17 00:00:00 2001
From: Mohammed-Faizzzz <110959467+Mohammed-Faizzzz@users.noreply.github.com>
Date: Wed, 23 Jul 2025 09:48:50 +0100
Subject: [PATCH 3/3] Raise Exception for unsupported model

---
 validator/modules/lora/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/validator/modules/lora/__init__.py b/validator/modules/lora/__init__.py
index e8ba50b..3c815cd 100644
--- a/validator/modules/lora/__init__.py
+++ b/validator/modules/lora/__init__.py
@@ -147,11 +147,14 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
                         logger.error(
                             f"LoRA model {model_repo} does not specify 'base_model_name_or_path' "
                         )
-                        return # exit function early
+                        raise InvalidConfigValueException(f"LoRA model {model_repo} does not specify 'base_model_name_or_path'.")
                     if base_model_path not in SUPPORTED_BASE_MODELS:  # SUPPORTED_BASE_MODELS must be defined in constant.py
                         logger.error(
                             f"LoRA's base model '{base_model_path}' is not in SUPPORTED_BASE_MODELS. "
                         )
+                        raise InvalidConfigValueException(
+                            f"LoRA's base model '{base_model_path}' is not in the supported list."
+                        )
                     tokenizer_repo = base_model_path
 
                 except (FileNotFoundError, JSONDecodeError) as e:  # the missing-file case is already handled above