From 0476b0040ac77d80427bb96623fe69738cf8672b Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 5 Feb 2026 18:36:15 +0000 Subject: [PATCH 1/3] Draft of updated lock notify --- .../apache_beam/ml/inference/model_manager.py | 47 ++++++++++++++----- .../ml/inference/model_manager_test.py | 14 ++++-- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index cc9f833c2682..7516c2a77815 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -288,6 +288,17 @@ def _solve(self): logger.error("Solver failed: %s", e) +class QueueTicket: + def __init__(self, priority, ticket_num, tag): + self.priority = priority + self.ticket_num = ticket_num + self.tag = tag + self.wake_event = threading.Event() + + def __lt__(self, other): + return (self.priority, self.ticket_num) < (other.priority, other.ticket_num) + + class ModelManager: """Manages model lifecycles, caching, and resource arbitration. @@ -343,6 +354,7 @@ def __init__( # and also priority for unknown models. self._wait_queue = [] self._ticket_counter = itertools.count() + self._cancelled_tickets = set() # TODO: Consider making the wait to be smarter, i.e. # splitting read/write etc. to avoid potential contention. 
self._cv = threading.Condition() @@ -417,10 +429,19 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool: self._cv.wait(timeout=self._lock_timeout_seconds) return False + def _wake_next_in_queue(self): + if self._wait_queue: + # Clean up cancelled tickets at head of queue + while self._wait_queue and self._wait_queue[ + 0].ticket_num in self._cancelled_tickets: + heapq.heappop(self._wait_queue) + self._cancelled_tickets.remove(self._wait_queue[0].ticket_num) + next_inline = self._wait_queue[0] + next_inline.wake_event.set() + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority = 0 if self._estimator.is_unknown(tag) else 1 ticket_num = next(self._ticket_counter) - my_id = object() with self._cv: # FAST PATH: Grab from idle LRU if available @@ -440,7 +461,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: len(self._models[tag]), ticket_num) heapq.heappush( - self._wait_queue, (current_priority, ticket_num, my_id, tag)) + self._wait_queue, QueueTicket(current_priority, ticket_num, tag)) est_cost = 0.0 is_unknown = False @@ -453,7 +474,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: raise RuntimeError( f"Timeout waiting to acquire model: {tag} " f"after {wait_time_elapsed:.1f} seconds.") - if not self._wait_queue or self._wait_queue[0][2] is not my_id: + if not self._wait_queue or self._wait_queue[ + 0].ticket_num != ticket_num: logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) self._cv.wait(timeout=self._lock_timeout_seconds) @@ -468,8 +490,9 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: heapq.heappop(self._wait_queue) current_priority = real_priority heapq.heappush( - self._wait_queue, (current_priority, ticket_num, my_id, tag)) - self._cv.notify_all() + self._wait_queue, + QueueTicket(current_priority, ticket_num, tag)) + self._wake_next_in_queue() continue # Try grab from LRU again in case model was 
released during wait @@ -508,7 +531,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: finally: # Remove self from wait queue once done - if self._wait_queue and self._wait_queue[0][2] is my_id: + if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num: heapq.heappop(self._wait_queue) else: logger.warning( @@ -516,11 +539,9 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: ", this is not expected: tag=%s ticket num=%s", tag, ticket_num) - for i, item in enumerate(self._wait_queue): - if item[2] is my_id: - self._wait_queue.pop(i) - heapq.heapify(self._wait_queue) - self._cv.notify_all() + # Marked as cancelled so that we skip when we reach head later + self._cancelled_tickets.add(ticket_num) + self._wake_next_in_queue() return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) @@ -553,7 +574,7 @@ def release_model(self, tag: str, instance: Any): self._estimator.add_observation(snapshot, peak_during_job) finally: - self._cv.notify_all() + self._wake_next_in_queue() def _try_grab_from_lru(self, tag: str) -> Any: target_key = None @@ -596,7 +617,7 @@ def _evict_to_make_space( # TODO: Also factor in the active counts to avoid thrashing demand_map = Counter() for item in self._wait_queue: - demand_map[item[3]] += 1 + demand_map[item.tag] += 1 my_demand = demand_map[requesting_tag] am_i_starving = len(self._models[requesting_tag]) == 0 diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 1bd8edd34d18..2ed3d3538945 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -174,12 +174,20 @@ def loader(): def acquire_model_with_timeout(): return self.manager.acquire_model(model_name, loader) - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(acquire_model_with_timeout) + with 
ThreadPoolExecutor(max_workers=16) as executor: + futures = [ + executor.submit(acquire_model_with_timeout) for i in range(1000) + ] with self.assertRaises(RuntimeError) as context: - future.result(timeout=5.0) + for future in futures: + future.result() self.assertIn("Timeout waiting to acquire model", str(context.exception)) + # Release the initially acquired model and try to acquire again + # to make sure the manager is still functional + self.manager.release_model(model_name, model_name) + _ = self.manager.acquire_model(model_name, loader) + def test_model_manager_capacity_check(self): """ Test that the manager blocks when spawning models exceeds the limit, From aee2900e1032f61c9076c788f7058dd1db7cb5f7 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 6 Feb 2026 00:36:51 +0000 Subject: [PATCH 2/3] Complete queue ticket implementation --- .../apache_beam/ml/inference/model_manager.py | 24 ++++++++++++------- .../ml/inference/model_manager_test.py | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 7516c2a77815..8a44bc467a40 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -434,14 +434,23 @@ def _wake_next_in_queue(self): # Clean up cancelled tickets at head of queue while self._wait_queue and self._wait_queue[ 0].ticket_num in self._cancelled_tickets: - heapq.heappop(self._wait_queue) self._cancelled_tickets.remove(self._wait_queue[0].ticket_num) + heapq.heappop(self._wait_queue) next_inline = self._wait_queue[0] next_inline.wake_event.set() + def _wait_in_queue(self, ticket: QueueTicket): + self._cv.release() + try: + ticket.wake_event.wait(timeout=self._lock_timeout_seconds) + ticket.wake_event.clear() + finally: + self._cv.acquire() + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority = 0 if 
self._estimator.is_unknown(tag) else 1 ticket_num = next(self._ticket_counter) + my_ticket = QueueTicket(current_priority, ticket_num, tag) with self._cv: # FAST PATH: Grab from idle LRU if available @@ -460,8 +469,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority, len(self._models[tag]), ticket_num) - heapq.heappush( - self._wait_queue, QueueTicket(current_priority, ticket_num, tag)) + heapq.heappush(self._wait_queue, my_ticket) est_cost = 0.0 is_unknown = False @@ -478,7 +486,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: 0].ticket_num != ticket_num: logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) - self._cv.wait(timeout=self._lock_timeout_seconds) + self._wait_in_queue(my_ticket) continue # Re-evaluate priority in case model became known during wait @@ -489,9 +497,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if current_priority != real_priority: heapq.heappop(self._wait_queue) current_priority = real_priority - heapq.heappush( - self._wait_queue, - QueueTicket(current_priority, ticket_num, tag)) + my_ticket = QueueTicket(current_priority, ticket_num, tag) + heapq.heappush(self._wait_queue, my_ticket) self._wake_next_in_queue() continue @@ -517,7 +524,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: "Waiting due to isolation in progress: tag=%s ticket num%s", tag, ticket_num) - self._cv.wait(timeout=self._lock_timeout_seconds) + self._wait_in_queue(my_ticket) continue if self.should_spawn_model(tag, ticket_num): @@ -575,6 +582,7 @@ def release_model(self, tag: str, instance: Any): finally: self._wake_next_in_queue() + self._cv.notify_all() def _try_grab_from_lru(self, tag: str) -> Any: target_key = None diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 2ed3d3538945..dd6d4a02e29e 100644 --- 
a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -174,7 +174,7 @@ def loader(): def acquire_model_with_timeout(): return self.manager.acquire_model(model_name, loader) - with ThreadPoolExecutor(max_workers=16) as executor: + with ThreadPoolExecutor(max_workers=1000) as executor: futures = [ executor.submit(acquire_model_with_timeout) for i in range(1000) ] From febb8b15a632b1a05a32b3fc13806fbf9884cef5 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 6 Feb 2026 21:08:15 +0000 Subject: [PATCH 3/3] Remove redundant warning log --- sdks/python/apache_beam/ml/inference/model_manager.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 8a44bc467a40..bf7c6a43ba63 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -541,11 +541,6 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num: heapq.heappop(self._wait_queue) else: - logger.warning( - "Item not at head of wait queue during cleanup" - ", this is not expected: tag=%s ticket num=%s", - tag, - ticket_num) # Marked as cancelled so that we skip when we reach head later self._cancelled_tickets.add(ticket_num) self._wake_next_in_queue()