From 56ed186894047ad59ed87e89156258e5977497f0 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Thu, 12 Feb 2026 21:55:18 -0800 Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20initia?= =?UTF-8?q?l=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.7 --- compiler_opt/baseline_cache.py | 22 +++++++-------- compiler_opt/baseline_cache_test.py | 43 +++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/compiler_opt/baseline_cache.py b/compiler_opt/baseline_cache.py index 96ae52b9..1febb5f9 100644 --- a/compiler_opt/baseline_cache.py +++ b/compiler_opt/baseline_cache.py @@ -22,32 +22,32 @@ class BaselineCache(Generic[T]): """Manages a cache of baseline scores.""" - def __init__(self, *, get_scores: Callable[[list[T]], list[float]], - get_key: Callable[[T], Any]): + def __init__(self, *, get_key: Callable[[T], Any]): """Constructor. Args: - get_scores: A callable that returns the scores for a batch of items. - The callable is responsible for timely completion. It must not - raise, and it must return results in the order of the items - provided. A None value is expected for items that could not produce - a value. get_key: A callable that returns the key for an item. """ - self._get_scores = get_scores self._get_key = get_key self._cache = {} - def get_score(self, items: list[T | None]): + def get_score(self, items: list[T | None], + get_scores_func: Callable[[list[T]], list[float]]): """Get the scores for a batch of items. The scores are returned in the same order as the provided items. A None result indicates the score could not be obtained. Args: items: A list of items to get scores for. + get_scores_func: A callable that returns the scores for a batch of + items. + The callable is responsible for timely completion. It must not + raise, and it must return results in the order of the items + provided. A None value is expected for items that could not produce + a value. """ - todo = [i for i in items if self._get_key(i) not in self._cache] - scores = self._get_scores(todo) + todo = {i for i in items if self._get_key(i) not in self._cache} + scores = get_scores_func(list(todo)) if len(scores) != len(todo): raise ValueError( "got a different number of results for the requested items") diff --git a/compiler_opt/baseline_cache_test.py b/compiler_opt/baseline_cache_test.py index 1997729c..88645de4 100644 --- a/compiler_opt/baseline_cache_test.py +++ b/compiler_opt/baseline_cache_test.py @@ -41,27 +41,49 @@ def track_score(lst): score_asked_for.extend(lst) return [mock[k] if k in mock else None for k in lst] - cache = baseline_cache.BaselineCache( - get_scores=track_score, get_key=lambda x: x) + cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) - self.assertEqual(cache.get_score(["c", "b"]), [3, 2]) + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3}) self.assertListEqual(score_asked_for, ["c", "b"]) score_asked_for.clear() - self.assertEqual(cache.get_score(["c", "b"]), [3, 2]) + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertListEqual(score_asked_for, []) - self.assertEqual(cache.get_score(["a", "c", "b"]), [1, 3, 2]) + self.assertEqual( + cache.get_score(["a", "c", "b"], get_scores_func=track_score), + [1, 3, 2]) self.assertListEqual(score_asked_for, ["a"]) score_asked_for.clear() - self.assertEqual(cache.get_score(["a", "n", "c", "b"]), [1, None, 3, 2]) + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertListEqual(score_asked_for, ["n"]) score_asked_for.clear() - self.assertEqual(cache.get_score(["a", "n", "c", "b"]), [1, None, 3, 2]) + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertListEqual(score_asked_for, []) + def test_duplicates(self): + mock = {"a": 1, "b": 2, "c": 3} + score_asked_for = [] + + def track_score(lst): + score_asked_for.extend(lst) + return [mock[k] if k in mock else None for k in lst] + + cache = baseline_cache.BaselineCache(get_key=lambda x: x) + self.assertEmpty(cache.get_cache()) + self.assertEqual( + cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score), + [3, 2, 3, 2]) + self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) + def test_with_workers(self): with local_worker_manager.LocalWorkerPoolManager( worker_class=MockWorker, count=4) as lwm: @@ -78,13 +100,12 @@ def get_scores(items: list[str]): futures, return_when=concurrent.futures.ALL_COMPLETED) return [f.result() if f.exception() is None else None for f in futures] - cache = baseline_cache.BaselineCache( - get_key=lambda x: x, get_scores=get_scores) + cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) - self.assertEqual(cache.get_score(["4", "2"]), [4, 2]) + self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertListEqual(score_asked_for, ["4", "2"]) self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2}) score_asked_for.clear() - self.assertEqual(cache.get_score(["4", "2"]), [4, 2]) + self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertListEqual(score_asked_for, [])