eval/task.py (35 changes: 34 additions & 1 deletion)
@@ -15,6 +15,9 @@
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM

# Safety buffer for vLLM max_gen_toks calculation
VLLM_SAFETY_BUFFER_TOKENS = 16


class BaseBenchmark(ABC):
"""Abstract base class for implementing LLM evaluation benchmarks."""
@@ -53,7 +56,37 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[Instance]:
if "4o" in model.model:
instance.args[1]["max_tokens"] = min(max_new_tokens, 16384)
elif isinstance(model, lm_eval_models.vllm_causallms.VLLM):
instance.args[1]["max_gen_toks"] = max_new_tokens
try:
# Get prompt from instance.args[0] (the templated string)
prompt = instance.args[0]
prompt_length = len(model.tokenizer.encode(prompt))

# Get max model length from vLLM engine
max_model_len = model.model.llm_engine.model_config.max_model_len

# Check if prompt itself exceeds model capacity
if prompt_length > max_model_len:
self.logger.warning(
f"Prompt length ({prompt_length}) exceeds model max length ({max_model_len}). "
f"Prompt will be truncated with no room for generation."
)

# Calculate max allowed generation tokens, leaving VLLM_SAFETY_BUFFER_TOKENS (16) as a safety buffer
max_allowed = max_model_len - prompt_length - VLLM_SAFETY_BUFFER_TOKENS
capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed))

if capped_max_new_tokens < max_new_tokens:
self.logger.warning(
f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} "
f"(prompt: {prompt_length} tokens, model max: {max_model_len})"
)

instance.args[1]["max_gen_toks"] = capped_max_new_tokens
except Exception as e:
self.logger.warning(
f"Failed to calculate max_gen_toks for vLLM, using original value: {e}"
)
instance.args[1]["max_gen_toks"] = max_new_tokens
else: # Huggingface
instance.args[1]["max_new_tokens"] = max_new_tokens
return instances
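
A minimal standalone sketch of the capping arithmetic this diff adds, for reference. The helper name cap_generation_tokens and the example token counts are illustrative assumptions, not part of the patch; the constant mirrors the 16-token safety buffer defined above.

# Hypothetical helper mirroring the capping logic in the diff (not part of the patch).
VLLM_SAFETY_BUFFER_TOKENS = 16

def cap_generation_tokens(prompt_length: int, max_model_len: int, max_new_tokens: int) -> int:
    """Return the number of generation tokens that still fits in the model's context."""
    # Room left after the prompt and the safety buffer; never drop below 1.
    max_allowed = max_model_len - prompt_length - VLLM_SAFETY_BUFFER_TOKENS
    return min(max_new_tokens, max(1, max_allowed))

# Example: a 3900-token prompt on a 4096-token model leaves 4096 - 3900 - 16 = 180 tokens,
# so a requested 512 new tokens is capped to 180; a short prompt is left uncapped.
assert cap_generation_tokens(3900, 4096, 512) == 180
assert cap_generation_tokens(100, 4096, 512) == 512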