247 changes: 200 additions & 47 deletions in validator/modules/lora/__init__.py
@@ -1,5 +1,6 @@
from validator.exceptions import RecoverableException
from validator.modules.base import BaseValidationModule, BaseConfig, BaseInputData, BaseMetrics
from constant import SUPPORTED_BASE_MODELS

# When raised, the assignment won't be marked as failed automatically and it will be retried after the user
# fixes the problem and restarts the process.
@@ -38,6 +39,18 @@ class LoRAValidationModule(BaseValidationModule):
def __init__(self, config: LoRAConfig, **kwargs):
# Store the config for later use
self.config = config

def validate_config(self):
"""Validate logical correctness of LoRA config values."""
import os

if self.config.per_device_eval_batch_size <= 0:
raise InvalidConfigValueException("`per_device_eval_batch_size` must be > 0.")
if not self.config.output_dir.strip():
raise InvalidConfigValueException("`output_dir` must not be empty or blank.")
if not os.path.isdir(self.config.output_dir):
raise InvalidConfigValueException(f"`output_dir` '{self.config.output_dir}' does not exist or is not a directory.")


def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
"""Run the validation procedure.
@@ -46,6 +59,9 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
pared-down so that it only performs the forward pass and metric
computation needed for the validator. All networking / ledger interaction
has been removed.

For error handling, we raise `InvalidConfigValueException` instead of simply returning,
as was previously done. Errors bubble up to the validation runner.
"""
import json
import math
@@ -57,6 +73,7 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import PeftModel
from json import JSONDecodeError

from .core import (
SFTDataCollator,
@@ -73,15 +90,30 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
# ------------------------------------------------------------------
# Build HF `TrainingArguments` object from config
# ------------------------------------------------------------------
val_args = TrainingArguments(
output_dir=self.config.output_dir,
per_device_eval_batch_size=self.config.per_device_eval_batch_size,
fp16=self.config.fp16,
remove_unused_columns=self.config.remove_unused_columns,
do_train=False,
do_eval=True,
report_to=[], # silence wandb etc.
)
# Validate config values before creating TrainingArguments (ideally this would happen in __init__)
logger.info("Validating config...")
try:
self.validate_config()
except InvalidConfigValueException as e:
logger.error(f"Config validation failed: {e}")
raise
logger.info("Config validation passed.")

# Wrap TrainingArguments creation in try/except: the config values may still trigger errors here
try:
logger.info("Creating TrainingArguments...")
val_args = TrainingArguments(
output_dir=self.config.output_dir,
per_device_eval_batch_size=self.config.per_device_eval_batch_size,
fp16=self.config.fp16,
remove_unused_columns=self.config.remove_unused_columns,
do_train=False,
do_eval=True,
report_to=[], # silence wandb etc.
)
except Exception as e:
logger.error(f"Failed to create TrainingArguments: {e}")
raise InvalidConfigValueException(f"TrainingArguments creation failed: {e}") # Might be good to introduce more exceptions for specific cases

model_repo = data.hg_repo_id
revision = data.revision or "main"
@@ -90,42 +122,132 @@ def validate(self, data: LoRAInputData, **kwargs) -> LoRAMetrics:
# ------------------------------------------------------------------
# Determine if the repo contains LoRA adapter weights or a full model
# ------------------------------------------------------------------
# Add more robust error handling for LoRA detection
is_lora = download_lora_config(model_repo, revision)
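
# (Editor's sketch: one plausible shape for `download_lora_config`, which is
# imported from .core; its real implementation is not shown in this diff.
# The sketch assumes huggingface_hub's `hf_hub_download`, and the filename
# and local directory are illustrative only.)
#
# from huggingface_hub import hf_hub_download
# from huggingface_hub.utils import EntryNotFoundError
#
# def download_lora_config(repo_id: str, revision: str) -> bool:
#     """Fetch adapter_config.json into ./lora if present; return whether the repo is a LoRA adapter."""
#     try:
#         hf_hub_download(repo_id, "adapter_config.json", revision=revision, local_dir="lora")
#         return True
#     except EntryNotFoundError:
#         return False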

if is_lora:
adapter_cfg_path = Path("lora/adapter_config.json")
try:
with open(adapter_cfg_path, "r") as f:
adapter_cfg = json.load(f)
base_model_path = adapter_cfg.get("base_model_name_or_path")
if base_model_path:
if not adapter_cfg_path.exists():
# is_lora is True, but adapter_config.json does not exist
logger.error(
f"Model {model_repo} is identified as LoRA, but its adapter_config.json was not downloaded or found at {adapter_config_path}. "
f"This could be due to an issue with 'download_lora_config' or the repository structure for the LoRA model. "
)
raise InvalidConfigValueException(f"adapter_config.json not found for LoRA model {model_repo}. ")
else:
logger.info(
f"Model {model_repo} is a LoRA model. Validating its base model for tokenizer."
)
try:
with open(adapter_cfg_path, "r") as f:
adapter_cfg = json.load(f)

base_model_path = adapter_cfg.get("base_model_name_or_path")
if not base_model_path or not base_model_path.strip():
logger.error(
f"LoRA model {model_repo} does not specify 'base_model_name_or_path'."
)
raise InvalidConfigValueException(f"LoRA model {model_repo} does not specify 'base_model_name_or_path'.")
if base_model_path not in SUPPORTED_BASE_MODELS:
logger.error(
f"LoRA base model '{base_model_path}' is not in SUPPORTED_BASE_MODELS."
)
raise InvalidConfigValueException(
f"LoRA base model '{base_model_path}' is not in the supported base model list."
)
tokenizer_repo = base_model_path
except (FileNotFoundError, JSONDecodeError) as e: # case where adapter_config.json is missing is already handled
logger.error(f"Failed to read adapter_config.json: {e}")
raise InvalidConfigValueException(f"adapter_config.json parsing failed: {e}")
except InvalidConfigValueException as e:
logger.error(str(e))
raise
except Exception as e:
logger.error(f"Unexpected error reading adapter_config.json: {e}")
raise
else:
logger.info(
f"Model {model_repo} is not identified as a LoRA model. "
f"Using its own path for tokenizer: {model_repo}."
)
tokenizer_repo = model_repo
# ------------------------------------------------------------------
# Tokeniser
# ------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Enclose the tokenizer loading in a try-except block to handle potential errors
def load_tokenizer(model_name_or_path: str) -> AutoTokenizer: # TODO: move this helper to a separate module and import it
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
use_fast=True,
)
if "gemma" in model_name_or_path.lower():
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}
)

if tokenizer.__class__.__name__ == "QWenTokenizer":
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
logger.info(f"vocab_size of tokenizer: {tokenizer.vocab_size}")
return tokenizer

try:
logger.info("Loading tokenizer …")
tokenizer = load_tokenizer(tokenizer_repo)
except Exception as e:
logger.error(f"Failed to load tokenizer from {tokenizer_repo}: {e}")
raise InvalidConfigValueException(f"Tokenizer loading failed: {e}")

# ------------------------------------------------------------------
# Evaluation dataset
# ------------------------------------------------------------------

template_name = data.base_model if data.base_model in template_dict else "default"
template = template_dict[template_name]
eval_dataset = UnifiedSFTDataset(
file=data.eval_file,
tokenizer=tokenizer,
max_seq_length=data.context_length,
template=template,
)

total_bytes, total_target_tokens = calculate_bytes_and_tokens(
eval_dataset, tokenizer, logger
)
token_byte_ratio_value = get_token_byte_ratio(total_target_tokens, total_bytes)


def load_sft_dataset(eval_file: str, max_seq_length: int, template_name: str, tokenizer: AutoTokenizer) -> UnifiedSFTDataset:
if template_name not in template_dict:
raise ValueError(
f"template_name '{template_name}' doesn't exist; available templates: {list(template_dict.keys())}"
)
template = template_dict[template_name]
logger.info("Loading data with UnifiedSFTDataset")
return UnifiedSFTDataset(eval_file, tokenizer, max_seq_length, template)

try:
# Possible errors here: data.eval_file not found, malformed records, etc.
template_name = data.base_model if data.base_model in template_dict else "default"
eval_dataset = load_sft_dataset(
eval_file=data.eval_file,
max_seq_length=data.context_length,
template_name=template_name,
tokenizer=tokenizer,
)

total_bytes, total_target_tokens = calculate_bytes_and_tokens(eval_dataset, tokenizer, logger)
if total_bytes == 0:
logger.warning(
"Total byte count of the evaluation dataset is 0; cannot calculate BPC. Check dataset processing."
)
token_byte_ratio_value = float("nan") # keep the name bound for the summary below
else:
logger.info(f"Total target tokens (T): {total_target_tokens}")
logger.info(f"Total target bytes (B): {total_bytes}")
token_byte_ratio_value = get_token_byte_ratio(
total_target_tokens, total_bytes
)
logger.info(f"Token/Byte ratio (T/B): {token_byte_ratio_value:.4f}")
if token_byte_ratio_value < 0.1:
logger.warning(
f"Token/Byte ratio ({token_byte_ratio_value:.4f}) is unusually low; this may indicate dataset manipulation."
)
except Exception as e:
logger.error(f"Failed to load evaluation dataset from {data.eval_file}: {e}")
raise InvalidConfigValueException(f"Evaluation dataset loading failed: {e}")

# ------------------------------------------------------------------
# Model loading helper
# ------------------------------------------------------------------
@@ -147,7 +269,15 @@ def _load_model() -> AutoModelForCausalLM:
return AutoModelForCausalLM.from_pretrained(model_repo, **model_kwargs)

logger.info("Loading model …")
model = _load_model()
try:
model = _load_model()
except Exception as e:
logger.error(f"Exception occurred while loading model: {e}")
raise InvalidConfigValueException(f"Model loading failed: {e}")

if model is None:
logger.error(f"Failed to load model from {model_repo}.")
raise InvalidConfigValueException(f"Model loading failed: {model_repo}.")

# Simple parameter count check
total_params = sum(p.numel() for p in model.parameters())
@@ -168,25 +298,41 @@ def _load_model() -> AutoModelForCausalLM:
# ------------------------------------------------------------------
# Prepare trainer and run evaluation
# ------------------------------------------------------------------
data_collator = SFTDataCollator(tokenizer, max_seq_length=data.context_length)
trainer = Trainer(
model=model,
args=val_args,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
try:
data_collator = SFTDataCollator(tokenizer, max_seq_length=data.context_length)
trainer = Trainer(
model=model,
args=val_args,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
except Exception as e:
logger.error(f"Failed to set up Trainer or data collator: {e}")
raise InvalidConfigValueException(f"Trainer setup failed: {e}")

logger.info("Running evaluation …")
eval_metrics = trainer.evaluate()
eval_loss = eval_metrics.get("eval_loss", float("nan"))
try:
logger.info("Running evaluation …")
eval_metrics = trainer.evaluate()
eval_loss = eval_metrics.get("eval_loss", float("nan"))
except Exception as e:
logger.error(f"Model evaluation failed: {e}")
raise RuntimeError(f"Evaluation failed: {e}")

# ------------------------------------------------------------------
# Compute derived metrics
# ------------------------------------------------------------------
bpc_metrics = calculate_bpc_bppl_metrics(
eval_loss, total_target_tokens, total_bytes
)
bpc_metrics = { # failsafe defaults in case total_bytes is 0
"loss": float("nan"),
"bpc": float("inf"),
"bppl": float("inf"),
"nll_token_nats_total": float("nan"),
"nll_token_bits_total": float("nan"),
}
if total_bytes > 0:
bpc_metrics = calculate_bpc_bppl_metrics(
eval_loss, total_target_tokens, total_bytes
)
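
# (Editor's note: a worked sketch of the standard BPC/BPPL derivation that the
# metric keys above suggest. `calculate_bpc_bppl_metrics` lives in .core, so
# this is an assumption about its contents, not a copy of it.)
#
# nll_token_nats_total = eval_loss * total_target_tokens     # mean NLL per token -> total nats
# nll_token_bits_total = nll_token_nats_total / math.log(2)  # nats -> bits
# bpc = nll_token_bits_total / total_bytes                   # bits per target byte
# bppl = 2 ** bpc                                            # byte-level perplexity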

_log_summary_table(
model_name_or_path=model_repo,
@@ -211,6 +357,13 @@ def _load_model() -> AutoModelForCausalLM:
# Clean-up GPU / CPU memory – important when running many validations
# ------------------------------------------------------------------
gc.collect()
if model is not None:
logger.debug("Offloading model to save memory")
model.cpu()
del model
if eval_dataset is not None:
logger.debug("Offloading eval_dataset to save memory")
del eval_dataset
torch.cuda.empty_cache()
# purge temporary lora directory if it exists
if os.path.exists("lora"):