From d4267d12a052b0a29acfa05a171826d01483eead Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 21:02:22 +0000 Subject: [PATCH 1/8] Update RunInference to work with model manager --- sdks/python/apache_beam/ml/inference/base.py | 100 ++++++++-- .../apache_beam/ml/inference/base_test.py | 78 +++++++- .../ml/inference/model_manager_it_test.py | 184 ++++++++++++++++++ 3 files changed, 337 insertions(+), 25 deletions(-) create mode 100644 sdks/python/apache_beam/ml/inference/model_manager_it_test.py diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index ad2e2f8d0e3c..e6a9a15b7a29 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -68,8 +68,10 @@ try: # pylint: disable=wrong-import-order, wrong-import-position import resource + from apache_beam.ml.inference.model_manager import ModelManager except ImportError: resource = None # type: ignore[assignment] + ModelManager = None # type: ignore[assignment] _NANOSECOND_TO_MILLISECOND = 1_000_000 _NANOSECOND_TO_MICROSECOND = 1_000 @@ -533,11 +535,12 @@ def request( raise NotImplementedError(type(self)) -class _ModelManager: +class _ModelHandlerManager: """ - A class for efficiently managing copies of multiple models. Will load a - single copy of each model into a multi_process_shared object and then - return a lookup key for that object. + A class for efficiently managing copies of multiple model handlers. + Will load a single copy of each model from the model handler into a + multi_process_shared object and then return a lookup key for that + object. Used for KeyedModelHandler only. """ def __init__(self, mh_map: dict[str, ModelHandler]): """ @@ -602,8 +605,9 @@ def load(self, key: str) -> _ModelLoadStats: def increment_max_models(self, increment: int): """ - Increments the number of models that this instance of a _ModelManager is - able to hold. If it is never called, no limit is imposed. + Increments the number of models that this instance of a + _ModelHandlerManager is able to hold. If it is never called, + no limit is imposed. Args: increment: the amount by which we are incrementing the number of models. """ @@ -656,7 +660,7 @@ def __init__( class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[tuple[KeyT, ExampleT], tuple[KeyT, PredictionT], - Union[ModelT, _ModelManager]]): + Union[ModelT, _ModelHandlerManager]]): def __init__( self, unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], @@ -809,15 +813,15 @@ def __init__( 'to exactly one model handler.') self._key_to_id_map[key] = keys[0] - def load_model(self) -> Union[ModelT, _ModelManager]: + def load_model(self) -> Union[ModelT, _ModelHandlerManager]: if self._single_model: return self._unkeyed.load_model() - return _ModelManager(self._id_to_mh_map) + return _ModelHandlerManager(self._id_to_mh_map) def run_inference( self, batch: Sequence[tuple[KeyT, ExampleT]], - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], inference_args: Optional[dict[str, Any]] = None ) -> Iterable[tuple[KeyT, PredictionT]]: if self._single_model: @@ -919,7 +923,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): def update_model_paths( self, - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], model_paths: list[KeyModelPathMapping[KeyT]] = None): # When there are many models, the keyed model handler is responsible for # reorganizing the model handlers into cohorts and telling the model @@ -1338,6 +1342,8 @@ def __init__( model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, watch_model_pattern: Optional[str] = None, model_identifier: Optional[str] = None, + use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None, **kwargs): """ A transform that takes a PCollection of examples (or features) for use @@ -1378,6 +1384,8 @@ def __init__( self._exception_handling_timeout = None self._timeout = None self._watch_model_pattern = watch_model_pattern + self._use_model_manager = use_model_manager + self._model_manager_args = model_manager_args self._kwargs = kwargs # Generate a random tag to use for shared.py and multi_process_shared.py to # allow us to effectively disambiguate in multi-model settings. Only use @@ -1490,7 +1498,9 @@ def expand( self._clock, self._metrics_namespace, load_model_at_runtime, - self._model_tag), + self._model_tag, + self._use_model_manager, + self._model_manager_args), self._inference_args, beam.pvalue.AsSingleton( self._model_metadata_pcoll, @@ -1803,21 +1813,50 @@ def load_model_status( return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag) +class _ProxyLoader: + """ + A helper callable to wrap the loader for MultiProcessShared. + """ + def __init__(self, loader_func, model_tag): + self.loader_func = loader_func + self.model_tag = model_tag + + def __call__(self): + unique_tag = self.model_tag + '_' + uuid.uuid4().hex + # Ensure that each model loaded in a different process for parallelism + multi_process_shared.MultiProcessShared( + self.loader_func, tag=unique_tag, always_proxy=True, + spawn_process=True).acquire() + # Only return the tag to avoid pickling issues with the model itself. + return unique_tag + + class _SharedModelWrapper(): """A router class to map incoming calls to the correct model. This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ - def __init__(self, models: list[Any], model_tag: str): + def __init__( + self, + models: Union[list[Any], ModelManager], + model_tag: str, + loader_func: Callable[[], Any] = None): self.models = models - if len(models) > 1: + self.use_model_manager = not isinstance(models, list) + self.model_tag = model_tag + self.loader_func = loader_func + if not self.use_model_manager and len(models) > 1: self.model_router = multi_process_shared.MultiProcessShared( lambda: _ModelRoutingStrategy(), tag=f'{model_tag}_counter', always_proxy=True).acquire() def next_model(self): + if self.use_model_manager: + loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag) + return self.models.acquire_model(self.model_tag, loader_wrapper) + if len(self.models) == 1: # Short circuit if there's no routing strategy needed in order to # avoid the cross-process call @@ -1825,9 +1864,19 @@ def next_model(self): return self.models[self.model_router.next_model_index(len(self.models))] + def release_model(self, model_tag: str, model: Any): + if self.use_model_manager: + self.models.release_model(model_tag, model) + def all_models(self): + if self.use_model_manager: + return self.models.all_models()[self.model_tag] return self.models + def force_reset(self): + if self.use_model_manager: + self.models.force_reset() + class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): def __init__( @@ -1836,7 +1885,9 @@ def __init__( clock, metrics_namespace, load_model_at_runtime: bool = False, - model_tag: str = "RunInference"): + model_tag: str = "RunInference", + use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None): """A DoFn implementation generic to frameworks. Args: @@ -1860,6 +1911,8 @@ def __init__( # _cur_tag is the tag of the actually loaded model self._model_tag = model_tag self._cur_tag = model_tag + self.use_model_manager = use_model_manager + self._model_manager_args = model_manager_args or {} def _load_model( self, @@ -1894,7 +1947,15 @@ def load(): model_tag = side_input_model_path # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) - if self._model_handler.share_model_across_processes(): + if self.use_model_manager: + logging.info("Using Model Manager to manage models automatically.") + model_manager = multi_process_shared.MultiProcessShared( + lambda: ModelManager(**self._model_manager_args), + tag='model_manager', + always_proxy=True).acquire() + model_wrapper = _SharedModelWrapper( + model_manager, self._cur_tag, self._model_handler.load_model) + elif self._model_handler.share_model_across_processes(): models = [] for copy_tag in _get_tags_for_copies(self._cur_tag, self._model_handler.model_copies()): @@ -1949,8 +2010,15 @@ def _run_inference(self, batch, inference_args): start_time = _to_microseconds(self._clock.time_ns()) try: model = self._model.next_model() + if isinstance(model, str): + # ModelManager with MultiProcessShared returns the model tag + unique_tag = model + model = multi_process_shared.MultiProcessShared( + lambda: None, tag=model, always_proxy=True).acquire() result_generator = self._model_handler.run_inference( batch, model, inference_args) + if self.use_model_manager: + self._model.release_model(self._model_tag, unique_tag) except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 55784166ad5d..fd35ca2e3ff2 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -17,6 +17,7 @@ """Tests for apache_beam.ml.base.""" import math +import multiprocessing import os import pickle import sys @@ -1599,13 +1600,13 @@ def test_child_class_without_env_vars(self): actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars()) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_model_manager_loads_shared_model(self): + def test_model_handler_manager_loads_shared_model(self): mhs = { 'key1': FakeModelHandler(state=1), 'key2': FakeModelHandler(state=2), 'key3': FakeModelHandler(state=3) } - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag # Use bad_mh's load function to make sure we're actually loading the # version already stored @@ -1623,12 +1624,12 @@ def test_model_manager_loads_shared_model(self): self.assertEqual(2, model2.predict(10)) self.assertEqual(3, model3.predict(10)) - def test_model_manager_evicts_models(self): + def test_model_handler_manager_evicts_models(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(2) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) @@ -1667,10 +1668,10 @@ def test_model_manager_evicts_models(self): mh3.load_model, tag=tag3).acquire() self.assertEqual(8, model3.predict(10)) - def test_model_manager_evicts_models_after_update(self): + def test_model_handler_manager_evicts_models_after_update(self): mh1 = FakeModelHandler(state=1) mhs = {'key1': mh1} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() @@ -1697,13 +1698,12 @@ def test_model_manager_evicts_models_after_update(self): self.assertEqual(6, model1.predict(10)) sh1.release(model1) - def test_model_manager_evicts_correct_num_of_models_after_being_incremented( - self): + def test_model_handler_manager_evicts_models_after_being_incremented(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(1) mm.increment_max_models(1) tag1 = mm.load('key1').model_tag @@ -2279,5 +2279,65 @@ def test_max_batch_duration_secs_only(self): self.assertEqual(kwargs, {'max_batch_duration_secs': 60}) +class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]): + def load_model(self): + return FakeModel() + + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: + for example in batch: + yield model.predict(example) + + +def try_import_model_manager(): + try: + # pylint: disable=unused-import + from apache_beam.ml.inference.model_manager import ModelManager + return True + except ImportError: + return False + + +class ModelManagerTest(unittest.TestCase): + """Tests for RunInference with Model Manager integration.""" + def tearDown(self): + for p in multiprocessing.active_children(): + p.terminate() + p.join() + + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') + def test_run_inference_impl_with_model_manager(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + SimpleFakeModelHanlder(), use_model_manager=True) + assert_that(actual, equal_to(expected), label='assert:inferences') + + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') + def test_run_inference_impl_with_model_manager_args(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + SimpleFakeModelHanlder(), + use_model_manager=True, + model_manager_args={ + 'slack_percentage': 0.2, + 'poll_interval': 1.0, + 'peak_window_seconds': 10.0, + 'min_data_points': 10, + 'smoothing_factor': 0.5 + }) + assert_that(actual, equal_to(expected), label='assert:inferences') + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py new file mode 100644 index 000000000000..609de0606a14 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to + +# pylint: disable=ungrouped-imports +try: + import torch + from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler +except ImportError as e: + raise unittest.SkipTest( + "HuggingFace model handler dependencies are not installed") + + +class HuggingFaceGpuTest(unittest.TestCase): + + # Skips the test if you run it on a machine without a GPU + @unittest.skipIf( + not torch.cuda.is_available(), "No GPU detected, skipping GPU test") + def test_sentiment_analysis_on_gpu_large_input(self): + """ + Runs inference on a GPU (device=0) with a larger set of inputs. + """ + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="distilbert-base-uncased-finetuned-sst-2-english", + device=0, + inference_args={"batch_size": 4}) + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "The weather is okay, but I wish it were sunnier.", + "Apache Beam makes parallel processing incredibly efficient.", + "I am extremely disappointed with the service.", + "Logic and reason are the pillars of good debugging.", + "I'm so happy today!", + "This error message is confusing and unhelpful.", + "The movie was fantastic and the acting was superb.", + "I hate waiting in line for so long." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) + + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) + + expected_labels = [ + 'POSITIVE', # "love this product" + 'NEGATIVE', # "worst experience" + 'NEGATIVE', # "weather is okay, but..." + 'POSITIVE', # "incredibly efficient" + 'NEGATIVE', # "disappointed" + 'POSITIVE', # "pillars of good debugging" + 'POSITIVE', # "so happy" + 'NEGATIVE', # "confusing and unhelpful" + 'POSITIVE', # "fantastic" + 'NEGATIVE' # "hate waiting" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, equal_to(expected_labels), label='CheckPredictions') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_sentiment_analysis_large_roberta_gpu(self): + """ + Runs inference using a Large architecture (RoBERTa-Large, ~355M params). + This tests if the GPU can handle larger weights and requires more VRAM. + """ + + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="Siebert/sentiment-roberta-large-english", + device=0, + inference_args={"batch_size": 2}) + + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "Apache Beam scales effortlessly to massive datasets.", + "I am somewhat annoyed by the delay.", + "The nuanced performance of this large model is impressive.", + "I regret buying this immediately.", + "The sunset looks beautiful tonight.", + "This documentation is sparse and misleading.", + "Winning the championship felt surreal.", + "I'm feeling very neutral about this whole situation." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) + + expected_labels = [ + 'POSITIVE', # love + 'NEGATIVE', # worst + 'POSITIVE', # scales effortlessly + 'NEGATIVE', # annoyed + 'POSITIVE', # impressive + 'NEGATIVE', # regret + 'POSITIVE', # beautiful + 'NEGATIVE', # misleading + 'POSITIVE', # surreal + 'NEGATIVE' # "neutral" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, + equal_to(expected_labels), + label='CheckPredictionsLarge') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_parallel_inference_branches(self): + """ + Tests a branching pipeline where one input source feeds two + RunInference transforms running in parallel. + + Topology: + [ Input Data ] + | + +--------+--------+ + | | + [ Translation ] [ Sentiment ] + """ + + translator_handler = HuggingFacePipelineModelHandler( + task="translation_en_to_es", + model="Helsinki-NLP/opus-mt-en-es", + device=0, + inference_args={"batch_size": 8}) + sentiment_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="nlptown/bert-base-multilingual-uncased-sentiment", + device=0, + inference_args={"batch_size": 8}) + base_examples = [ + "I love this product.", + "This is terrible.", + "Hello world.", + "The service was okay.", + "I am extremely angry." + ] + MULTIPLIER = 10 + examples = base_examples * MULTIPLIER + + with TestPipeline() as pipeline: + inputs = pipeline | 'CreateInputs' >> beam.Create(examples) + _ = ( + inputs + | 'RunTranslation' >> RunInference( + translator_handler, use_model_manager=True) + | 'ExtractSpanish' >> + beam.Map(lambda x: x.inference['translation_text'])) + _ = ( + inputs + | 'RunSentiment' >> RunInference( + sentiment_handler, use_model_manager=True) + | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) From 68690217a602984eb00963687a5fe1cef39ec16e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 23:26:22 +0000 Subject: [PATCH 2/8] Fix lint --- .../python/apache_beam/ml/inference/model_manager_it_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 609de0606a14..49b7a1373b94 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -16,14 +16,17 @@ # import unittest + import apache_beam as beam from apache_beam.ml.inference.base import RunInference from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that, equal_to +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to # pylint: disable=ungrouped-imports try: import torch + from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler except ImportError as e: raise unittest.SkipTest( From cd006474acb9a06dbdfd1d8ba92b6d57d71f8974 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 23:49:44 +0000 Subject: [PATCH 3/8] More lint --- sdks/python/apache_beam/ml/inference/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index e6a9a15b7a29..0dc7197b3b05 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -68,6 +68,7 @@ try: # pylint: disable=wrong-import-order, wrong-import-position import resource + from apache_beam.ml.inference.model_manager import ModelManager except ImportError: resource = None # type: ignore[assignment] From fb3e608bcaabc2d72d750c2eeff1924d42526677 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 5 Feb 2026 17:51:45 +0000 Subject: [PATCH 4/8] Add unittest main call --- sdks/python/apache_beam/ml/inference/model_manager_it_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 49b7a1373b94..eaa645b12166 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -185,3 +185,7 @@ def test_parallel_inference_branches(self): | 'RunSentiment' >> RunInference( sentiment_handler, use_model_manager=True) | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) + + +if __name__ == "__main__": + unittest.main() From 0de1c782de148819e765deb66fc30c2b8d718799 Mon Sep 17 00:00:00 2001 From: "RuiLong J." Date: Fri, 6 Feb 2026 13:08:53 -0800 Subject: [PATCH 5/8] Update sdks/python/apache_beam/ml/inference/base.py Co-authored-by: Danny McCormick --- sdks/python/apache_beam/ml/inference/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 0dc7197b3b05..ebfa8a837a73 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1842,7 +1842,7 @@ def __init__( self, models: Union[list[Any], ModelManager], model_tag: str, - loader_func: Callable[[], Any] = None): + loader_func: Optional[Callable[[], Any]] = None): self.models = models self.use_model_manager = not isinstance(models, list) self.model_tag = model_tag From e60167a04e9546c7a16320a5d06565ea98dbf9de Mon Sep 17 00:00:00 2001 From: "RuiLong J." Date: Fri, 6 Feb 2026 13:09:10 -0800 Subject: [PATCH 6/8] Update sdks/python/apache_beam/ml/inference/base_test.py Co-authored-by: Danny McCormick --- sdks/python/apache_beam/ml/inference/base_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index fd35ca2e3ff2..273f675afc52 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2279,7 +2279,7 @@ def test_max_batch_duration_secs_only(self): self.assertEqual(kwargs, {'max_batch_duration_secs': 60}) -class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]): +class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): def load_model(self): return FakeModel() From b22b3e4668f34be6c9c547733cee978e5f175cc6 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 6 Feb 2026 21:13:31 +0000 Subject: [PATCH 7/8] Add some comments explaining the model loading logistics --- sdks/python/apache_beam/ml/inference/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index ebfa8a837a73..1c3f0918bafd 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1823,6 +1823,11 @@ def __init__(self, loader_func, model_tag): self.model_tag = model_tag def __call__(self): + # Generate a unique tag for the model being loaded so that + # we will have unique instances of the model in multi_process_shared + # space instead of reusing the same instance over. The instance will + # be initialized and left running as a separate process, which then + # can be grabbed again using the unique tag if needed during inference. unique_tag = self.model_tag + '_' + uuid.uuid4().hex # Ensure that each model loaded in a different process for parallelism multi_process_shared.MultiProcessShared( From 43cf365f45da2785c74fcb8ac32ee820869af661 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 6 Feb 2026 21:55:39 +0000 Subject: [PATCH 8/8] Update name in tests as well --- sdks/python/apache_beam/ml/inference/base_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 273f675afc52..feccd8b0f12e 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2316,7 +2316,7 @@ def test_run_inference_impl_with_model_manager(self): expected = [example + 1 for example in examples] pcoll = pipeline | 'start' >> beam.Create(examples) actual = pcoll | base.RunInference( - SimpleFakeModelHanlder(), use_model_manager=True) + SimpleFakeModelHandler(), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') @unittest.skipIf( @@ -2327,7 +2327,7 @@ def test_run_inference_impl_with_model_manager_args(self): expected = [example + 1 for example in examples] pcoll = pipeline | 'start' >> beam.Create(examples) actual = pcoll | base.RunInference( - SimpleFakeModelHanlder(), + SimpleFakeModelHandler(), use_model_manager=True, model_manager_args={ 'slack_percentage': 0.2,