diff --git a/eval/task.py b/eval/task.py
index 70962115..204c2a12 100644
--- a/eval/task.py
+++ b/eval/task.py
@@ -15,6 +15,9 @@
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 
+# Safety buffer for vLLM max_gen_toks calculation
+VLLM_SAFETY_BUFFER_TOKENS = 16
+
 
 class BaseBenchmark(ABC):
     """Abstract base class for implementing LLM evaluation benchmarks."""
@@ -53,7 +56,37 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In
             if "4o" in model.model:
                 instance.args[1]["max_tokens"] = min(max_new_tokens, 16384)
             elif isinstance(model, lm_eval_models.vllm_causallms.VLLM):
-                instance.args[1]["max_gen_toks"] = max_new_tokens
+                try:
+                    # Get prompt from instance.args[0] (the templated string)
+                    prompt = instance.args[0]
+                    prompt_length = len(model.tokenizer.encode(prompt))
+
+                    # Get max model length from vLLM engine
+                    max_model_len = model.model.llm_engine.model_config.max_model_len
+
+                    # Check if prompt itself exceeds model capacity
+                    if prompt_length > max_model_len:
+                        self.logger.warning(
+                            f"Prompt length ({prompt_length}) exceeds model max length ({max_model_len}). "
+                            f"Prompt will be truncated with no room for generation."
+                        )
+
+                    # Calculate max allowed generation tokens (16 token safety buffer)
+                    max_allowed = max_model_len - prompt_length - VLLM_SAFETY_BUFFER_TOKENS
+                    capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed))
+
+                    if capped_max_new_tokens < max_new_tokens:
+                        self.logger.warning(
+                            f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} "
+                            f"(prompt: {prompt_length} tokens, model max: {max_model_len})"
+                        )
+
+                    instance.args[1]["max_gen_toks"] = capped_max_new_tokens
+                except Exception as e:
+                    self.logger.warning(
+                        f"Failed to calculate max_gen_toks for vLLM, using original value: {e}"
+                    )
+                    instance.args[1]["max_gen_toks"] = max_new_tokens
             else:  # Huggingface
                 instance.args[1]["max_new_tokens"] = max_new_tokens
         return instances
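
As a rough sanity check of the capping arithmetic introduced above (a standalone sketch, not part of the patch; the helper name and the example token counts are made up):

    # Standalone sketch: prompt + generation + safety buffer must fit in the context window.
    VLLM_SAFETY_BUFFER_TOKENS = 16

    def cap_generation_tokens(prompt_tokens: int, max_model_len: int, max_new_tokens: int) -> int:
        # Same arithmetic as the patched vLLM branch: leave a small buffer, never drop below 1.
        max_allowed = max_model_len - prompt_tokens - VLLM_SAFETY_BUFFER_TOKENS
        return min(max_new_tokens, max(1, max_allowed))

    # A 4000-token prompt in a 4096-token window leaves 4096 - 4000 - 16 = 80 generation tokens.
    assert cap_generation_tokens(4000, 4096, 512) == 80
    # A prompt that already overflows the window falls back to the floor of 1 token.
    assert cap_generation_tokens(5000, 4096, 512) == 1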