|
| 1 | +from typing import cast |
1 | 2 | from unittest.mock import MagicMock, patch |
2 | 3 |
|
3 | 4 | import numpy |
@@ -98,23 +99,22 @@ async def test_naive_reranker_rerank(naive_reranker_conf, query_result): |
98 | 99 |
|
99 | 100 | @patch("sentence_transformers.CrossEncoder") |
100 | 101 | def test_cross_encoder_reranker_initialization(mock_cross_encoder: MagicMock, config): |
| 102 | + model_name = config.reranker_params["model_name_or_path"] |
101 | 103 | reranker = CrossEncoderReranker(config) |
102 | | - |
103 | 104 | # Verify constructor was called with correct parameters |
104 | | - mock_cross_encoder.assert_called_once_with(**config.reranker_params) |
| 105 | + mock_cross_encoder.assert_called_once_with(model_name, **config.reranker_params) |
105 | 106 | assert reranker.n_result == config.n_result |
106 | 107 |
|
107 | 108 |
|
108 | 109 | @patch("sentence_transformers.CrossEncoder") |
109 | 110 | def test_cross_encoder_reranker_initialization_fallback_model_name( |
110 | 111 | mock_cross_encoder: MagicMock, config |
111 | 112 | ): |
112 | | - expected_params = {"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"} |
113 | 113 | config.reranker_params = {} |
114 | 114 | reranker = CrossEncoderReranker(config) |
115 | 115 |
|
116 | 116 | # Verify constructor was called with correct parameters |
117 | | - mock_cross_encoder.assert_called_once_with(**expected_params) |
| 117 | + mock_cross_encoder.assert_called_once_with("cross-encoder/ms-marco-MiniLM-L-6-v2") |
118 | 118 | assert reranker.n_result == config.n_result |
119 | 119 |
|
120 | 120 |
|
@@ -211,14 +211,10 @@ def test_get_reranker(config, naive_reranker_conf): |
211 | 211 | reranker = get_reranker(config) |
212 | 212 | assert reranker.configs.reranker == "CrossEncoderReranker" |
213 | 213 |
|
214 | | - reranker = get_reranker(config) |
| 214 | + reranker = cast(CrossEncoderReranker, get_reranker(config)) |
215 | 215 | assert reranker.configs.reranker == "CrossEncoderReranker", ( |
216 | 216 | "configs.reranker should fallback to 'CrossEncoderReranker'" |
217 | 217 | ) |
218 | | - assert ( |
219 | | - reranker.configs.reranker_params.get("model_name_or_path") |
220 | | - == "cross-encoder/ms-marco-MiniLM-L-6-v2" |
221 | | - ), "configs.reranker_params should fallback to default params." |
222 | 218 |
|
223 | 219 |
|
224 | 220 | def test_supported_rerankers_initialization(config, naive_reranker_conf): |
|
0 commit comments