From fd80b51dc3c9681b36b68491929a6525a8805198 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 9 Feb 2026 20:39:16 +0000 Subject: [PATCH 1/8] Add OOM protection handling for RunInference --- sdks/python/apache_beam/ml/inference/base.py | 29 ++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 1c3f0918bafd..07128f8e8636 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1330,6 +1330,29 @@ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: return self._base.get_postprocess_fns() + [self._postprocess_fn] +class OOMProtectedFn: + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + try: + return self.func(*args, **kwargs) + except Exception as e: + # Check string to avoid hard import dependency + if 'out of memory' in str(e) and 'CUDA' in str(e): + logging.warning("Caught CUDA OOM during operation. Cleaning memory.") + try: + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + except ImportError: + pass + except Exception as cleanup_error: + logging.error("Failed to clean up CUDA memory: %s", cleanup_error) + raise e + + class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, Iterable[ExampleT]]], beam.PCollection[PredictionT]]): @@ -1831,7 +1854,9 @@ def __call__(self): unique_tag = self.model_tag + '_' + uuid.uuid4().hex # Ensure that each model loaded in a different process for parallelism multi_process_shared.MultiProcessShared( - self.loader_func, tag=unique_tag, always_proxy=True, + OOMProtectedFn(self.loader_func), + tag=unique_tag, + always_proxy=True, spawn_process=True).acquire() # Only return the tag to avoid pickling issues with the model itself. return unique_tag @@ -2021,7 +2046,7 @@ def _run_inference(self, batch, inference_args): unique_tag = model model = multi_process_shared.MultiProcessShared( lambda: None, tag=model, always_proxy=True).acquire() - result_generator = self._model_handler.run_inference( + result_generator = (OOMProtectedFn(self._model_handler.run_inference))( batch, model, inference_args) if self.use_model_manager: self._model.release_model(self._model_tag, unique_tag) From 2eab9476cea23a295656d992f936095d7603773d Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 02:09:04 +0000 Subject: [PATCH 2/8] Make sure we release the model regardless --- sdks/python/apache_beam/ml/inference/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 07128f8e8636..eb9626f8d475 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -2046,10 +2046,13 @@ def _run_inference(self, batch, inference_args): unique_tag = model model = multi_process_shared.MultiProcessShared( lambda: None, tag=model, always_proxy=True).acquire() - result_generator = (OOMProtectedFn(self._model_handler.run_inference))( - batch, model, inference_args) - if self.use_model_manager: - self._model.release_model(self._model_tag, unique_tag) + try: + result_generator = (OOMProtectedFn(self._model_handler.run_inference))( + batch, model, inference_args) + finally: + # Always release the model so that it can be reloaded. + if self.use_model_manager: + self._model.release_model(self._model_tag, unique_tag) except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() From bd8ca7ee0c154c1b68c956c43dd73c96b5934b13 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 02:41:37 +0000 Subject: [PATCH 3/8] Add testing coverage --- .../apache_beam/ml/inference/base_test.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index feccd8b0f12e..5950ef4e8f22 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2338,6 +2338,28 @@ def test_run_inference_impl_with_model_manager_args(self): }) assert_that(actual, equal_to(expected), label='assert:inferences') + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') + def test_run_inference_impl_with_model_manager_oom(self): + class OOMFakeModelHandler(SimpleFakeModelHandler): + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: + if random.random() < 0.8: + raise MemoryError("Simulated OOM") + for example in batch: + yield model.predict(example) + + with self.assertRaises(Exception): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + OOMFakeModelHandler(), use_model_manager=True) + assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences') + if __name__ == '__main__': unittest.main() From ab0620d6e9496338465859b9d07666f52e4c41e6 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 03:30:00 +0000 Subject: [PATCH 4/8] Lint and make logging optional --- .../apache_beam/ml/inference/base_test.py | 1 + .../apache_beam/ml/inference/model_manager.py | 51 ++++++++++++------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 5950ef4e8f22..de8227604325 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -20,6 +20,7 @@ import multiprocessing import os import pickle +import random import sys import tempfile import time diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index bf7c6a43ba63..70efd6fb421d 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -176,14 +176,23 @@ class ResourceEstimator: individual models based on aggregate system memory readings and the configuration of active models at that time. """ - def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): + def __init__( + self, + smoothing_factor: float = 0.2, + min_data_points: int = 5, + verbose_logging: bool = False): self.smoothing_factor = smoothing_factor self.min_data_points = min_data_points + self.verbose_logging = verbose_logging self.estimates: Dict[str, float] = {} self.history = defaultdict(lambda: deque(maxlen=20)) self.known_models = set() self._lock = threading.Lock() + def logging_info(self, message: str, *args): + if self.verbose_logging: + logger.info(message, *args) + def is_unknown(self, model_tag: str) -> bool: with self._lock: return model_tag not in self.estimates @@ -196,7 +205,7 @@ def set_initial_estimate(self, model_tag: str, cost: float): with self._lock: self.estimates[model_tag] = cost self.known_models.add(model_tag) - logger.info("Initial Profile for %s: %s MB", model_tag, cost) + self.logging_info("Initial Profile for %s: %s MB", model_tag, cost) def add_observation( self, active_snapshot: Dict[str, int], peak_memory: float): @@ -207,7 +216,7 @@ def add_observation( else: model_list = "\t- None" - logger.info( + self.logging_info( "Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s", peak_memory, model_list) @@ -256,7 +265,7 @@ def _solve(self): # Not enough data to solve yet return - logger.info( + self.logging_info( "Solving with %s total observations for %s models.", len(A), len(unique)) @@ -280,9 +289,9 @@ def _solve(self): else: self.estimates[model] = calculated_cost - logger.info( + self.logging_info( "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) - logger.info("System Bias: %s MB", bias) + self.logging_info("System Bias: %s MB", bias) except Exception as e: logger.error("Solver failed: %s", e) @@ -321,7 +330,8 @@ def __init__( eviction_cooldown_seconds: float = 10.0, min_model_copies: int = 1, wait_timeout_seconds: float = 300.0, - lock_timeout_seconds: float = 60.0): + lock_timeout_seconds: float = 60.0, + verbose_logging: bool = False): self._estimator = ResourceEstimator( min_data_points=min_data_points, smoothing_factor=smoothing_factor) @@ -333,6 +343,7 @@ def __init__( self._min_model_copies = min_model_copies self._wait_timeout_seconds = wait_timeout_seconds self._lock_timeout_seconds = lock_timeout_seconds + self._verbose_logging = verbose_logging # Resource State self._models = defaultdict(list) @@ -361,20 +372,24 @@ def __init__( self._monitor.start() + def logging_info(self, message: str, *args): + if self._verbose_logging: + logger.info(message, *args) + def all_models(self, tag) -> list[Any]: return self._models[tag] # Should hold _cv lock when calling def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: if self._total_active_jobs > 0: - logger.info( + self.logging_info( "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num) self._cv.wait(timeout=self._lock_timeout_seconds) # return False since we have waited and need to re-evaluate # in caller to make sure our priority is still valid. return False - logger.info("Unknown model %s detected. Flushing GPU.", tag) + self.logging_info("Unknown model %s detected. Flushing GPU.", tag) self._delete_all_models() self._isolation_mode = True @@ -412,7 +427,7 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool: for _, instances in self._models.items(): total_model_count += len(instances) curr, _, _ = self._monitor.get_stats() - logger.info( + self.logging_info( "Waiting for resources to free up: " "tag=%s ticket num%s model count=%s " "idle count=%s resource usage=%.1f MB " @@ -462,7 +477,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # SLOW PATH: Enqueue and wait for turn to acquire model, # with unknown models having priority and order enforced # by ticket number as FIFO. - logger.info( + self.logging_info( "Acquire Queued: tag=%s, priority=%d " "total models count=%s ticket num=%s", tag, @@ -484,7 +499,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: f"after {wait_time_elapsed:.1f} seconds.") if not self._wait_queue or self._wait_queue[ 0].ticket_num != ticket_num: - logger.info( + self.logging_info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) self._wait_in_queue(my_ticket) continue @@ -520,7 +535,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Path B: Concurrent else: if self._isolation_mode: - logger.info( + self.logging_info( "Waiting due to isolation in progress: tag=%s ticket num%s", tag, ticket_num) @@ -596,7 +611,7 @@ def _try_grab_from_lru(self, tag: str) -> Any: self._total_active_jobs += 1 return target_instance - logger.info("No idle model found for tag: %s", tag) + self.logging_info("No idle model found for tag: %s", tag) return None def _evict_to_make_space( @@ -679,9 +694,9 @@ def _delete_instance(self, instance: Any): del instance def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): - logger.info("Evicting Model: %s (Score %d)", tag, score) + self.logging_info("Evicting Model: %s (Score %d)", tag, score) curr, _, _ = self._monitor.get_stats() - logger.info("Resource Usage Before Eviction: %.1f MB", curr) + self.logging_info("Resource Usage Before Eviction: %.1f MB", curr) if key in self._idle_lru: del self._idle_lru[key] @@ -697,7 +712,7 @@ def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): self._monitor.refresh() self._monitor.reset_peak() curr, _, _ = self._monitor.get_stats() - logger.info("Resource Usage After Eviction: %.1f MB", curr) + self.logging_info("Resource Usage After Eviction: %.1f MB", curr) def _spawn_new_model( self, @@ -707,7 +722,7 @@ def _spawn_new_model( est_cost: float) -> Any: try: with self._cv: - logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) + self.logging_info("Loading Model: %s (Unknown: %s)", tag, is_unknown) baseline_snap, _, _ = self._monitor.get_stats() instance = loader_func() _, peak_during_load, _ = self._monitor.get_stats() From 6f8103b2ac92a192199a821a33d769524a6aa1b9 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 03:35:49 +0000 Subject: [PATCH 5/8] Pass verbose logging setting from model manager to estimator --- sdks/python/apache_beam/ml/inference/model_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 70efd6fb421d..c35c1a5aea8e 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -334,7 +334,9 @@ def __init__( verbose_logging: bool = False): self._estimator = ResourceEstimator( - min_data_points=min_data_points, smoothing_factor=smoothing_factor) + min_data_points=min_data_points, + smoothing_factor=smoothing_factor, + verbose_logging=verbose_logging) self._monitor = monitor if monitor else GPUMonitor( poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) self._slack_percentage = slack_percentage From 71e068471b7be3555c2a6ebddf1cc77c16988a49 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 04:48:45 +0000 Subject: [PATCH 6/8] Lint --- sdks/python/apache_beam/ml/inference/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index eb9626f8d475..ef5d15264b5c 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1343,6 +1343,7 @@ def __call__(self, *args, **kwargs): logging.warning("Caught CUDA OOM during operation. Cleaning memory.") try: import gc + import torch gc.collect() torch.cuda.empty_cache() From 3acb0423d4f418863eb60c48f86785ccb1a76523 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 10 Feb 2026 18:34:50 +0000 Subject: [PATCH 7/8] Prevent flakes on EOFError --- sdks/python/apache_beam/ml/inference/model_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index c35c1a5aea8e..186611984df0 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -688,8 +688,12 @@ def _delete_instance(self, instance: Any): if isinstance(instance, str): # If the instance is a string, it's a uuid used # to retrieve the model from MultiProcessShared - multi_process_shared.MultiProcessShared( - lambda: "N/A", tag=instance).unsafe_hard_delete() + try: + multi_process_shared.MultiProcessShared( + lambda: "N/A", tag=instance).unsafe_hard_delete() + except (EOFError, OSError, BrokenPipeError): + # This can happen even in normal operation. + pass if hasattr(instance, 'mock_model_unsafe_hard_delete'): # Call the mock unsafe hard delete method for testing instance.mock_model_unsafe_hard_delete() From fcb49c45cd4f172fa6601ae6cccb0ab21aaae12c Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 11 Feb 2026 21:59:29 +0000 Subject: [PATCH 8/8] Enforce batch size --- sdks/python/apache_beam/ml/inference/base_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index de8227604325..8236ac5c1e5f 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2353,6 +2353,9 @@ def run_inference( for example in batch: yield model.predict(example) + def batch_elements_kwargs(self): + return {'min_batch_size': 1, 'max_batch_size': 1} + with self.assertRaises(Exception): with TestPipeline() as pipeline: examples = [1, 5, 3, 10]