diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index cc9f833c2682..bf7c6a43ba63 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -288,6 +288,17 @@ def _solve(self): logger.error("Solver failed: %s", e) +class QueueTicket: + def __init__(self, priority, ticket_num, tag): + self.priority = priority + self.ticket_num = ticket_num + self.tag = tag + self.wake_event = threading.Event() + + def __lt__(self, other): + return (self.priority, self.ticket_num) < (other.priority, other.ticket_num) + + class ModelManager: """Manages model lifecycles, caching, and resource arbitration. @@ -343,6 +354,7 @@ def __init__( # and also priority for unknown models. self._wait_queue = [] self._ticket_counter = itertools.count() + self._cancelled_tickets = set() # TODO: Consider making the wait to be smarter, i.e. # splitting read/write etc. to avoid potential contention. self._cv = threading.Condition() @@ -417,10 +429,28 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool: self._cv.wait(timeout=self._lock_timeout_seconds) return False + def _wake_next_in_queue(self): + if self._wait_queue: + # Clean up cancelled tickets at head of queue + while self._wait_queue and self._wait_queue[ + 0].ticket_num in self._cancelled_tickets: + self._cancelled_tickets.remove(self._wait_queue[0].ticket_num) + heapq.heappop(self._wait_queue) + next_inline = self._wait_queue[0] + next_inline.wake_event.set() + + def _wait_in_queue(self, ticket: QueueTicket): + self._cv.release() + try: + ticket.wake_event.wait(timeout=self._lock_timeout_seconds) + ticket.wake_event.clear() + finally: + self._cv.acquire() + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority = 0 if self._estimator.is_unknown(tag) else 1 ticket_num = next(self._ticket_counter) - my_id = object() + my_ticket = QueueTicket(current_priority, ticket_num, tag) with self._cv: # FAST PATH: Grab from idle LRU if available @@ -439,8 +469,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority, len(self._models[tag]), ticket_num) - heapq.heappush( - self._wait_queue, (current_priority, ticket_num, my_id, tag)) + heapq.heappush(self._wait_queue, my_ticket) est_cost = 0.0 is_unknown = False @@ -453,10 +482,11 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: raise RuntimeError( f"Timeout waiting to acquire model: {tag} " f"after {wait_time_elapsed:.1f} seconds.") - if not self._wait_queue or self._wait_queue[0][2] is not my_id: + if not self._wait_queue or self._wait_queue[ + 0].ticket_num != ticket_num: logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) - self._cv.wait(timeout=self._lock_timeout_seconds) + self._wait_in_queue(my_ticket) continue # Re-evaluate priority in case model became known during wait @@ -467,9 +497,9 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if current_priority != real_priority: heapq.heappop(self._wait_queue) current_priority = real_priority - heapq.heappush( - self._wait_queue, (current_priority, ticket_num, my_id, tag)) - self._cv.notify_all() + my_ticket = QueueTicket(current_priority, ticket_num, tag) + heapq.heappush(self._wait_queue, my_ticket) + self._wake_next_in_queue() continue # Try grab from LRU again in case model was released during wait @@ -494,7 +524,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: "Waiting due to isolation in progress: tag=%s ticket num%s", tag, ticket_num) - self._cv.wait(timeout=self._lock_timeout_seconds) + self._wait_in_queue(my_ticket) continue if self.should_spawn_model(tag, ticket_num): @@ -508,19 +538,12 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: finally: # Remove self from wait queue once done - if self._wait_queue and self._wait_queue[0][2] is my_id: + if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num: heapq.heappop(self._wait_queue) else: - logger.warning( - "Item not at head of wait queue during cleanup" - ", this is not expected: tag=%s ticket num=%s", - tag, - ticket_num) - for i, item in enumerate(self._wait_queue): - if item[2] is my_id: - self._wait_queue.pop(i) - heapq.heapify(self._wait_queue) - self._cv.notify_all() + # Marked as cancelled so that we skip when we reach head later + self._cancelled_tickets.add(ticket_num) + self._wake_next_in_queue() return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) @@ -553,6 +576,7 @@ def release_model(self, tag: str, instance: Any): self._estimator.add_observation(snapshot, peak_during_job) finally: + self._wake_next_in_queue() self._cv.notify_all() def _try_grab_from_lru(self, tag: str) -> Any: @@ -596,7 +620,7 @@ def _evict_to_make_space( # TODO: Also factor in the active counts to avoid thrashing demand_map = Counter() for item in self._wait_queue: - demand_map[item[3]] += 1 + demand_map[item.tag] += 1 my_demand = demand_map[requesting_tag] am_i_starving = len(self._models[requesting_tag]) == 0 diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 7cfb73cb668f..270401857e04 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -174,12 +174,20 @@ def loader(): def acquire_model_with_timeout(): return self.manager.acquire_model(model_name, loader) - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(acquire_model_with_timeout) + with ThreadPoolExecutor(max_workers=1000) as executor: + futures = [ + executor.submit(acquire_model_with_timeout) for i in range(1000) + ] with self.assertRaises(RuntimeError) as context: - future.result(timeout=5.0) + for future in futures: + future.result() self.assertIn("Timeout waiting to acquire model", str(context.exception)) + # Release the initially acquired model and try to acquire again + # to make sure the manager is still functional + self.manager.release_model(model_name, model_name) + _ = self.manager.acquire_model(model_name, loader) + def test_model_manager_capacity_check(self): """ Test that the manager blocks when spawning models exceeds the limit,