diff --git a/compiler_opt/baseline_cache_test.py b/compiler_opt/baseline_cache_test.py index 4885c8a9..d9cc9364 100644 --- a/compiler_opt/baseline_cache_test.py +++ b/compiler_opt/baseline_cache_test.py @@ -41,29 +41,43 @@ def track_score(lst): score_asked_for.extend(lst) return [mock[k] if k in mock else None for k in lst] + cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3}) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) score_asked_for.clear() + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertListEqual(score_asked_for, []) + self.assertEqual( + cache.get_score(["a", "c", "b"], get_scores_func=track_score), + [1, 3, 2]) self.assertEqual( cache.get_score(["a", "c", "b"], get_scores_func=track_score), [1, 3, 2]) self.assertListEqual(score_asked_for, ["a"]) score_asked_for.clear() + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) self.assertListEqual(score_asked_for, ["n"]) score_asked_for.clear() + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) @@ -84,6 +98,21 @@ def track_score(lst): [3, 2, 3, 2]) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) + def test_duplicates(self): + mock = {"a": 1, "b": 2, "c": 3} + score_asked_for = [] + + def track_score(lst): + score_asked_for.extend(lst) + return [mock[k] if k in mock else None for k in lst] + + cache = baseline_cache.BaselineCache(get_key=lambda x: x) + self.assertEmpty(cache.get_cache()) + self.assertEqual( + cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score), + [3, 2, 3, 2]) + self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) + def test_with_workers(self): with local_worker_manager.LocalWorkerPoolManager( worker_class=MockWorker, count=4) as lwm: @@ -100,12 +129,14 @@ def get_scores(items: list[str]): futures, return_when=concurrent.futures.ALL_COMPLETED) return [f.result() if f.exception() is None else None for f in futures] + cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) - self.assertListEqual(sorted(score_asked_for), sorted(["4", "2"])) + self.assertListEqual(score_asked_for, ["4", "2"]) self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2}) score_asked_for.clear() + self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertListEqual(score_asked_for, [])