Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@
try:
# pylint: disable=wrong-import-order, wrong-import-position
import resource

from apache_beam.ml.inference.model_manager import ModelManager
except ImportError:
resource = None # type: ignore[assignment]
ModelManager = None # type: ignore[assignment]

_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
Expand Down Expand Up @@ -533,11 +536,12 @@ def request(
raise NotImplementedError(type(self))


class _ModelManager:
class _ModelHandlerManager:
"""
A class for efficiently managing copies of multiple models. Will load a
single copy of each model into a multi_process_shared object and then
return a lookup key for that object.
A class for efficiently managing copies of multiple model handlers.
Will load a single copy of each model from the model handler into a
multi_process_shared object and then return a lookup key for that
object. Used for KeyedModelHandler only.
"""
def __init__(self, mh_map: dict[str, ModelHandler]):
"""
Expand Down Expand Up @@ -602,8 +606,9 @@ def load(self, key: str) -> _ModelLoadStats:

def increment_max_models(self, increment: int):
"""
Increments the number of models that this instance of a _ModelManager is
able to hold. If it is never called, no limit is imposed.
Increments the number of models that this instance of a
_ModelHandlerManager is able to hold. If it is never called,
no limit is imposed.
Args:
increment: the amount by which we are incrementing the number of models.
"""
Expand Down Expand Up @@ -656,7 +661,7 @@ def __init__(
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[tuple[KeyT, ExampleT],
tuple[KeyT, PredictionT],
Union[ModelT, _ModelManager]]):
Union[ModelT, _ModelHandlerManager]]):
def __init__(
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
Expand Down Expand Up @@ -809,15 +814,15 @@ def __init__(
'to exactly one model handler.')
self._key_to_id_map[key] = keys[0]

def load_model(self) -> Union[ModelT, _ModelManager]:
def load_model(self) -> Union[ModelT, _ModelHandlerManager]:
if self._single_model:
return self._unkeyed.load_model()
return _ModelManager(self._id_to_mh_map)
return _ModelHandlerManager(self._id_to_mh_map)

def run_inference(
self,
batch: Sequence[tuple[KeyT, ExampleT]],
model: Union[ModelT, _ModelManager],
model: Union[ModelT, _ModelHandlerManager],
inference_args: Optional[dict[str, Any]] = None
) -> Iterable[tuple[KeyT, PredictionT]]:
if self._single_model:
Expand Down Expand Up @@ -919,7 +924,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):

def update_model_paths(
self,
model: Union[ModelT, _ModelManager],
model: Union[ModelT, _ModelHandlerManager],
model_paths: list[KeyModelPathMapping[KeyT]] = None):
# When there are many models, the keyed model handler is responsible for
# reorganizing the model handlers into cohorts and telling the model
Expand Down Expand Up @@ -1338,6 +1343,8 @@ def __init__(
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
watch_model_pattern: Optional[str] = None,
model_identifier: Optional[str] = None,
use_model_manager: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @AMOOOMA, would you mind following up on this PR to add docstrings to the new RunInference params added here? Thanks a lot!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it sounds like we might want to feature this functionality in CHANGES.md

model_manager_args: Optional[dict[str, Any]] = None,
**kwargs):
"""
A transform that takes a PCollection of examples (or features) for use
Expand Down Expand Up @@ -1378,6 +1385,8 @@ def __init__(
self._exception_handling_timeout = None
self._timeout = None
self._watch_model_pattern = watch_model_pattern
self._use_model_manager = use_model_manager
self._model_manager_args = model_manager_args
self._kwargs = kwargs
# Generate a random tag to use for shared.py and multi_process_shared.py to
# allow us to effectively disambiguate in multi-model settings. Only use
Expand Down Expand Up @@ -1490,7 +1499,9 @@ def expand(
self._clock,
self._metrics_namespace,
load_model_at_runtime,
self._model_tag),
self._model_tag,
self._use_model_manager,
self._model_manager_args),
self._inference_args,
beam.pvalue.AsSingleton(
self._model_metadata_pcoll,
Expand Down Expand Up @@ -1803,31 +1814,75 @@ def load_model_status(
return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)


class _ProxyLoader:
"""
A helper callable to wrap the loader for MultiProcessShared.
"""
def __init__(self, loader_func, model_tag):
self.loader_func = loader_func
self.model_tag = model_tag

def __call__(self):
# Generate a unique tag for the model being loaded so that
# we will have unique instances of the model in multi_process_shared
# space instead of reusing the same instance over. The instance will
# be initialized and left running as a separate process, which then
# can be grabbed again using the unique tag if needed during inference.
unique_tag = self.model_tag + '_' + uuid.uuid4().hex
# Ensure that each model loaded in a different process for parallelism
multi_process_shared.MultiProcessShared(
self.loader_func, tag=unique_tag, always_proxy=True,
spawn_process=True).acquire()
# Only return the tag to avoid pickling issues with the model itself.
return unique_tag


class _SharedModelWrapper():
  """A router class to map incoming calls to the correct model.

  This allows us to round robin calls to models sitting in different
  processes so that we can more efficiently use resources (e.g. GPUs).
  """
  def __init__(
      self,
      models: Union[list[Any], ModelManager],
      model_tag: str,
      loader_func: Optional[Callable[[], Any]] = None):
    self.models = models
    # A plain list means the caller preloaded the model copies itself;
    # anything else is treated as a ModelManager that loads on demand.
    self.use_model_manager = not isinstance(models, list)
    self.model_tag = model_tag
    self.loader_func = loader_func
    if not self.use_model_manager and len(models) > 1:
      # Round-robin state lives in shared space so every worker process
      # advances the same counter.
      self.model_router = multi_process_shared.MultiProcessShared(
          lambda: _ModelRoutingStrategy(),
          tag=f'{model_tag}_counter',
          always_proxy=True).acquire()

  def next_model(self):
    if self.use_model_manager:
      # Hand the manager a picklable loader; it returns a lookup tag for
      # the (possibly freshly loaded) model instance.
      wrapped_loader = _ProxyLoader(self.loader_func, self.model_tag)
      return self.models.acquire_model(self.model_tag, wrapped_loader)

    if len(self.models) == 1:
      # Single copy: skip the router to avoid the cross-process call.
      return self.models[0]

    next_index = self.model_router.next_model_index(len(self.models))
    return self.models[next_index]

  def release_model(self, model_tag: str, model: Any):
    # Only the model manager tracks checked-out models; the preloaded-list
    # variant has nothing to release.
    if self.use_model_manager:
      self.models.release_model(model_tag, model)

  def all_models(self):
    if self.use_model_manager:
      return self.models.all_models()[self.model_tag]
    return self.models

  def force_reset(self):
    # No-op unless a model manager is in charge of the model lifecycle.
    if self.use_model_manager:
      self.models.force_reset()


class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
def __init__(
Expand All @@ -1836,7 +1891,9 @@ def __init__(
clock,
metrics_namespace,
load_model_at_runtime: bool = False,
model_tag: str = "RunInference"):
model_tag: str = "RunInference",
use_model_manager: bool = False,
model_manager_args: Optional[dict[str, Any]] = None):
"""A DoFn implementation generic to frameworks.

Args:
Expand All @@ -1860,6 +1917,8 @@ def __init__(
# _cur_tag is the tag of the actually loaded model
self._model_tag = model_tag
self._cur_tag = model_tag
self.use_model_manager = use_model_manager
self._model_manager_args = model_manager_args or {}

def _load_model(
self,
Expand Down Expand Up @@ -1894,7 +1953,15 @@ def load():
model_tag = side_input_model_path
# Ensure the tag we're loading is valid, if not replace it with a valid tag
self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
if self._model_handler.share_model_across_processes():
if self.use_model_manager:
logging.info("Using Model Manager to manage models automatically.")
model_manager = multi_process_shared.MultiProcessShared(
lambda: ModelManager(**self._model_manager_args),
tag='model_manager',
always_proxy=True).acquire()
model_wrapper = _SharedModelWrapper(
model_manager, self._cur_tag, self._model_handler.load_model)
elif self._model_handler.share_model_across_processes():
models = []
for copy_tag in _get_tags_for_copies(self._cur_tag,
self._model_handler.model_copies()):
Expand Down Expand Up @@ -1949,8 +2016,15 @@ def _run_inference(self, batch, inference_args):
start_time = _to_microseconds(self._clock.time_ns())
try:
model = self._model.next_model()
if isinstance(model, str):
# ModelManager with MultiProcessShared returns the model tag
unique_tag = model
model = multi_process_shared.MultiProcessShared(
lambda: None, tag=model, always_proxy=True).acquire()
result_generator = self._model_handler.run_inference(
batch, model, inference_args)
if self.use_model_manager:
self._model.release_model(self._model_tag, unique_tag)
except BaseException as e:
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
Expand Down
78 changes: 69 additions & 9 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

"""Tests for apache_beam.ml.base."""
import math
import multiprocessing
import os
import pickle
import sys
Expand Down Expand Up @@ -1599,13 +1600,13 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_model_manager_loads_shared_model(self):
def test_model_handler_manager_loads_shared_model(self):
mhs = {
'key1': FakeModelHandler(state=1),
'key2': FakeModelHandler(state=2),
'key3': FakeModelHandler(state=3)
}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
# Use bad_mh's load function to make sure we're actually loading the
# version already stored
Expand All @@ -1623,12 +1624,12 @@ def test_model_manager_loads_shared_model(self):
self.assertEqual(2, model2.predict(10))
self.assertEqual(3, model3.predict(10))

def test_model_manager_evicts_models(self):
def test_model_handler_manager_evicts_models(self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
mm.increment_max_models(2)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
Expand Down Expand Up @@ -1667,10 +1668,10 @@ def test_model_manager_evicts_models(self):
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))

def test_model_manager_evicts_models_after_update(self):
def test_model_handler_manager_evicts_models_after_update(self):
mh1 = FakeModelHandler(state=1)
mhs = {'key1': mh1}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
Expand All @@ -1697,13 +1698,12 @@ def test_model_manager_evicts_models_after_update(self):
self.assertEqual(6, model1.predict(10))
sh1.release(model1)

def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
self):
def test_model_handler_manager_evicts_models_after_being_incremented(self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
mm.increment_max_models(1)
mm.increment_max_models(1)
tag1 = mm.load('key1').model_tag
Expand Down Expand Up @@ -2279,5 +2279,65 @@ def test_max_batch_duration_secs_only(self):
self.assertEqual(kwargs, {'max_batch_duration_secs': 60})


class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
  """Minimal int-in/int-out ModelHandler over FakeModel for manager tests."""
  def load_model(self):
    return FakeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: FakeModel,
      inference_args=None) -> Iterable[int]:
    # Lazily yield one prediction per input example.
    yield from (model.predict(example) for example in batch)


def try_import_model_manager():
  """Return True when the optional ModelManager dependency is importable.

  Used to skip the manager-dependent tests in environments where the
  model_manager extra is not installed.
  """
  try:
    # pylint: disable=unused-import
    from apache_beam.ml.inference.model_manager import ModelManager
  except ImportError:
    return False
  return True


class ModelManagerTest(unittest.TestCase):
  """Tests for RunInference with Model Manager integration."""
  def tearDown(self):
    # Reap any model processes spawned via multi_process_shared so they
    # do not leak across test cases.
    for child in multiprocessing.active_children():
      child.terminate()
      child.join()

  def _assert_increments_each_example(self, **run_inference_kwargs):
    # Shared driver: FakeModel adds 1 to every input example.
    examples = [1, 5, 3, 10]
    expected = [example + 1 for example in examples]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start' >> beam.Create(examples)
      actual = pcoll | base.RunInference(
          SimpleFakeModelHandler(), **run_inference_kwargs)
      assert_that(actual, equal_to(expected), label='assert:inferences')

  @unittest.skipIf(
      not try_import_model_manager(), 'Model Manager not available')
  def test_run_inference_impl_with_model_manager(self):
    self._assert_increments_each_example(use_model_manager=True)

  @unittest.skipIf(
      not try_import_model_manager(), 'Model Manager not available')
  def test_run_inference_impl_with_model_manager_args(self):
    self._assert_increments_each_example(
        use_model_manager=True,
        model_manager_args={
            'slack_percentage': 0.2,
            'poll_interval': 1.0,
            'peak_window_seconds': 10.0,
            'min_data_points': 10,
            'smoothing_factor': 0.5
        })


# Allow running this test module directly (e.g. `python base_test.py`).
if __name__ == '__main__':
  unittest.main()
Loading
Loading