Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions compiler_opt/baseline_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,33 @@
class BaselineCache(Generic[T]):
"""Manages a cache of baseline scores."""

def __init__(self, *, get_scores: Callable[[list[T]], list[float]],
get_key: Callable[[T], Any]):
def __init__(self, *, get_key: Callable[[T], Any]):
"""Constructor.

Args:
get_scores: A callable that returns the scores for a batch of items.
The callable is responsible for timely completion. It must not
raise, and it must return results in the order of the items
provided. A None value is expected for items that could not produce
a value.
get_key: A callable that returns the key for an item.
"""
self._get_scores = get_scores
self._get_key = get_key
self._cache = {}

def get_score(self, items: list[T | None]):
def get_score(self, items: list[T | None],
get_scores_func: Callable[[list[T]], list[float]]):
"""Get the scores for a batch of items.
The scores are returned in the same order as the provided items. A None
result indicates the score could not be obtained.

Args:
items: A list of items to get scores for.
get_scores_func: A callable that returns the scores for a batch of
items.

get_scores_func: Responsible for timely completion. It must not
raise, and it must return results in the order of the items
provided. A None value is expected for items that could not
produce a value.
"""
todo = [i for i in items if self._get_key(i) not in self._cache]
scores = self._get_scores(todo)
todo = {i for i in items if self._get_key(i) not in self._cache}
scores = get_scores_func(list(todo))
if len(scores) != len(todo):
raise ValueError(
"got a different number of results for the requested items")
Expand Down
47 changes: 34 additions & 13 deletions compiler_opt/baseline_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,49 @@ def track_score(lst):
score_asked_for.extend(lst)
return [mock[k] if k in mock else None for k in lst]

cache = baseline_cache.BaselineCache(
get_scores=track_score, get_key=lambda x: x)
cache = baseline_cache.BaselineCache(get_key=lambda x: x)
self.assertEmpty(cache.get_cache())
self.assertEqual(cache.get_score(["c", "b"]), [3, 2])
self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3})
self.assertListEqual(score_asked_for, ["c", "b"])
self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"]))
score_asked_for.clear()

self.assertEqual(cache.get_score(["c", "b"]), [3, 2])
self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertListEqual(score_asked_for, [])
self.assertEqual(cache.get_score(["a", "c", "b"]), [1, 3, 2])
self.assertEqual(
cache.get_score(["a", "c", "b"], get_scores_func=track_score),
[1, 3, 2])
self.assertListEqual(score_asked_for, ["a"])
score_asked_for.clear()

self.assertEqual(cache.get_score(["a", "n", "c", "b"]), [1, None, 3, 2])
self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
self.assertListEqual(score_asked_for, ["n"])
score_asked_for.clear()

self.assertEqual(cache.get_score(["a", "n", "c", "b"]), [1, None, 3, 2])
self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
self.assertListEqual(score_asked_for, [])

def test_duplicates(self):
mock = {"a": 1, "b": 2, "c": 3}
score_asked_for = []

def track_score(lst):
score_asked_for.extend(lst)
return [mock[k] if k in mock else None for k in lst]

cache = baseline_cache.BaselineCache(get_key=lambda x: x)
self.assertEmpty(cache.get_cache())
self.assertEqual(
cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score),
[3, 2, 3, 2])
self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"]))

def test_with_workers(self):
with local_worker_manager.LocalWorkerPoolManager(
worker_class=MockWorker, count=4) as lwm:
Expand All @@ -78,13 +100,12 @@ def get_scores(items: list[str]):
futures, return_when=concurrent.futures.ALL_COMPLETED)
return [f.result() if f.exception() is None else None for f in futures]

cache = baseline_cache.BaselineCache(
get_key=lambda x: x, get_scores=get_scores)
cache = baseline_cache.BaselineCache(get_key=lambda x: x)
self.assertEmpty(cache.get_cache())
self.assertEqual(cache.get_score(["4", "2"]), [4, 2])
self.assertListEqual(score_asked_for, ["4", "2"])
self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2])
self.assertListEqual(sorted(score_asked_for), sorted(["4", "2"]))
self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2})
score_asked_for.clear()

self.assertEqual(cache.get_score(["4", "2"]), [4, 2])
self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2])
self.assertListEqual(score_asked_for, [])