Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flexeval/core/language_model/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def wrapper(self: VLLM, *args: tuple, **kwargs: dict) -> Callable:

self.llm = LLM(self.model_name, **self.model_kwargs)
if self.model_limit_tokens == "default":
self.model_limit_tokens = self.llm.llm_engine.get_model_config().max_model_len
self.model_limit_tokens = self.llm.llm_engine.model_config.max_model_len
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return method(self, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -306,7 +306,7 @@ def _batch_compute_log_probs(
prefix + continuation for prefix, continuation in zip(batch_prefix_ids, batch_continuation_ids)
]

max_length = self.llm.llm_engine.get_model_config().max_seq_len_to_capture
max_length = self.llm.llm_engine.model_config.max_model_len
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stride = stride or max_length // 2
if not (0 < stride < max_length):
msg = f"stride must be in (0, {max_length}), but got {stride}"
Expand All @@ -315,7 +315,7 @@ def _batch_compute_log_probs(

from vllm import RequestOutput, SamplingParams
from vllm.inputs import TokensPrompt
from vllm.sequence import Logprob
from vllm.logprobs import Logprob
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


sampling_params = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=1)

Expand Down
2,438 changes: 1,714 additions & 724 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ sacrebleu = {extras = ["ja"], version = "^2.4.1"}
jiwer = "^3.0.4"
openai = "^1.52.2"
google-api-python-client = "^2.131.0"
vllm = {version = "0.10.2", optional = true }
vllm = {version = "0.16.0", optional = true }
loguru = "^0.7.2"
wandb = {version = "^0.17.2", optional = true}
pyarrow = "16.1.0" # set the version because we get "Unable to find installation candidates" with 17.0.0
Expand Down
24 changes: 2 additions & 22 deletions tests/core/language_model/vllm/test_vllm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,14 @@ def chat_lm() -> VLLM:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()


@pytest.fixture(scope="module")
def chat_lm_with_system_message() -> VLLM:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why this function was originally placed here, but since it was causing errors, I moved it to test_vllm_specific.py.

llm = VLLM(
model="sbintuitions/tiny-lm-chat",
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
system_message="You are a helpful assistant.",
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
Expand Down
15 changes: 6 additions & 9 deletions tests/core/language_model/vllm/test_vllm_custom_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@ def chat_lm_with_custom_chat_template() -> VLLM:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
custom_chat_template=custom_chat_template,
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


# With 1 1 1, the continuation was not 1.
Expand All @@ -55,6 +54,7 @@ def chat_lm_with_fill_zeros() -> VLLM:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
Expand All @@ -63,9 +63,7 @@ def chat_lm_with_fill_zeros() -> VLLM:
chat_template_kwargs={"fill_with_zeros": True},
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.fixture(scope="module")
Expand All @@ -78,6 +76,7 @@ def chat_lm_with_fill_xs() -> VLLM:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
Expand All @@ -86,9 +85,7 @@ def chat_lm_with_fill_xs() -> VLLM:
chat_template_kwargs={"fill_with_zeros": False},
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
Expand Down
3 changes: 2 additions & 1 deletion tests/core/language_model/vllm/test_vllm_serve_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def chat_lm() -> VLLMServeLM:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
"tokenizer_mode": "slow",
},
)
yield llm
llm.manager.stop()
llm.cleanup_resources()
if openai_api_key is not None:
os.environ["OPENAI_API_KEY"] = openai_api_key

Expand Down
57 changes: 41 additions & 16 deletions tests/core/language_model/vllm/test_vllm_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,20 @@


@pytest.fixture(scope="module")
def chat_lm() -> Generator[VLLM, None, None]:
def chat_lm() -> VLLM:
llm = VLLM(
model="sbintuitions/tiny-lm-chat",
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"dtype": "float32",
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.fixture(scope="module")
Expand All @@ -39,14 +37,29 @@ def chat_lm_qwen() -> Generator[VLLM, None, None]:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
},
)
yield llm
llm.cleanup_resources()


@pytest.fixture(scope="module")
def chat_lm_with_system_message() -> VLLM:
llm = VLLM(
model="sbintuitions/tiny-lm-chat",
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
system_message="You are a helpful assistant.",
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.fixture(scope="module")
Expand All @@ -57,17 +70,15 @@ def chat_lm_for_tool_calling() -> Generator[VLLM, None, None]:
model_kwargs={
"seed": 42,
"gpu_memory_utilization": 0.1,
"max_model_len": 2048,
"enforce_eager": True,
"dtype": "float32",
"disable_custom_all_reduce": True,
},
tokenizer_kwargs={"use_fast": False},
tool_parser=tool_parser,
)
yield llm
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

cleanup_dist_env_and_memory()
llm.cleanup_resources()


@pytest.fixture(scope="module")
Expand All @@ -77,24 +88,38 @@ def hf_lm(model_name: str = "sbintuitions/tiny-lm-chat") -> HuggingFaceLM:
)


@pytest.fixture(scope="module")
def hf_lm_qwen(model_name: str = "Qwen/Qwen3-0.6B-Base") -> HuggingFaceLM:
return HuggingFaceLM(
model=model_name, model_kwargs={"torch_dtype": "float32"}, default_gen_kwargs={"temperature": 0.0}
)


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
@pytest.mark.parametrize("chat_lm_name", ["chat_lm", "chat_lm_qwen"])
@pytest.mark.parametrize(
("chat_lm_name", "hf_lm_name"),
[
("chat_lm", "hf_lm"),
("chat_lm_qwen", "hf_lm_qwen"),
],
)
def test_batch_compute_log_probs_approximates_hf_lm(
request: pytest.FixtureRequest,
chat_lm_name: str,
hf_lm: HuggingFaceLM,
hf_lm_name: str,
) -> None:
chat_lm = request.getfixturevalue(chat_lm_name)
hf_lm = request.getfixturevalue(hf_lm_name)
prefix_list = ["それは正しい日本語ですか?"]
text_list = ["これは正しい日本語です。"]

vllm_log_probs = chat_lm.compute_log_probs(text_list)
hf_log_probs = hf_lm.compute_log_probs(text_list)
assert vllm_log_probs == pytest.approx(hf_log_probs, abs=1e-2)
assert vllm_log_probs == pytest.approx(hf_log_probs, abs=0.5)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that differences of this magnitude can occur depending on the seed (or the environment), so I widened the acceptable tolerance range.
The values are roughly around -33.2 (Qwen) and -47.2 (Sarashina), so allowing an error margin of about ±0.5 seems reasonable in practice.


vllm_log_probs = chat_lm.compute_log_probs(text_list, prefix_list=prefix_list)
hf_log_probs = hf_lm.compute_log_probs(text_list, prefix_list=prefix_list)
assert vllm_log_probs == pytest.approx(hf_log_probs, abs=1e-2)
assert vllm_log_probs == pytest.approx(hf_log_probs, abs=0.5)


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
Expand Down