Skip to content

Commit 141bf87

Browse files
committed
fix(cli): Correct reranker parameter usage
1 parent 8727832 commit 141bf87

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

src/vectorcode/subcommands/query/reranker/cross_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def __init__(
3131
configs.reranker_params["model_name_or_path"] = (
3232
"cross-encoder/ms-marco-MiniLM-L-6-v2"
3333
)
34-
self.model = CrossEncoder(**configs.reranker_params)
34+
model_name = configs.reranker_params.pop("model_name_or_path")
35+
self.model = CrossEncoder(model_name, **configs.reranker_params)
3536

3637
async def compute_similarity(self, results: list[str], query_message: str):
3738
scores = await asyncio.to_thread(

tests/subcommands/query/test_reranker.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import cast
12
from unittest.mock import MagicMock, patch
23

34
import numpy
@@ -98,23 +99,22 @@ async def test_naive_reranker_rerank(naive_reranker_conf, query_result):
9899

99100
@patch("sentence_transformers.CrossEncoder")
100101
def test_cross_encoder_reranker_initialization(mock_cross_encoder: MagicMock, config):
102+
model_name = config.reranker_params["model_name_or_path"]
101103
reranker = CrossEncoderReranker(config)
102-
103104
# Verify constructor was called with correct parameters
104-
mock_cross_encoder.assert_called_once_with(**config.reranker_params)
105+
mock_cross_encoder.assert_called_once_with(model_name, **config.reranker_params)
105106
assert reranker.n_result == config.n_result
106107

107108

108109
@patch("sentence_transformers.CrossEncoder")
109110
def test_cross_encoder_reranker_initialization_fallback_model_name(
110111
mock_cross_encoder: MagicMock, config
111112
):
112-
expected_params = {"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}
113113
config.reranker_params = {}
114114
reranker = CrossEncoderReranker(config)
115115

116116
# Verify constructor was called with correct parameters
117-
mock_cross_encoder.assert_called_once_with(**expected_params)
117+
mock_cross_encoder.assert_called_once_with("cross-encoder/ms-marco-MiniLM-L-6-v2")
118118
assert reranker.n_result == config.n_result
119119

120120

@@ -211,14 +211,10 @@ def test_get_reranker(config, naive_reranker_conf):
211211
reranker = get_reranker(config)
212212
assert reranker.configs.reranker == "CrossEncoderReranker"
213213

214-
reranker = get_reranker(config)
214+
reranker = cast(CrossEncoderReranker, get_reranker(config))
215215
assert reranker.configs.reranker == "CrossEncoderReranker", (
216216
"configs.reranker should fallback to 'CrossEncoderReranker'"
217217
)
218-
assert (
219-
reranker.configs.reranker_params.get("model_name_or_path")
220-
== "cross-encoder/ms-marco-MiniLM-L-6-v2"
221-
), "configs.reranker_params should fallback to default params."
222218

223219

224220
def test_supported_rerankers_initialization(config, naive_reranker_conf):

0 commit comments

Comments
 (0)