From c0862299e031534b60e60f314a199e1b356aa6f6 Mon Sep 17 00:00:00 2001 From: Jinen Setpal Date: Tue, 7 Jan 2025 20:18:59 -0400 Subject: [PATCH 1/2] disentangled mechanism to obtain mlflow model --- dagshub/data_engine/model/query_result.py | 29 ++------------ dagshub/mlflow/__init__.py | 3 +- dagshub/mlflow/get_model.py | 48 +++++++++++++++++++++++ 3 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 dagshub/mlflow/get_model.py diff --git a/dagshub/data_engine/model/query_result.py b/dagshub/data_engine/model/query_result.py index 061e6c80..6c1cf41f 100644 --- a/dagshub/data_engine/model/query_result.py +++ b/dagshub/data_engine/model/query_result.py @@ -10,7 +10,6 @@ import json import os import os.path -import importlib import dacite import dagshub_annotation_converter.converters.yolo @@ -19,16 +18,15 @@ from dagshub_annotation_converter.formats.yolo.categories import Categories from dagshub_annotation_converter.formats.yolo.common import ir_mapping from dagshub_annotation_converter.ir.image import IRImageAnnotationBase +from dagshub.mlflow import get_mlflow_model from pydantic import ValidationError -from dagshub.auth import get_token from dagshub.common import config +from dagshub.common.util import lazy_load from dagshub.common.analytics import send_analytics_event -from dagshub.common.api import UserAPI from dagshub.common.download import download_files from dagshub.common.helpers import sizeof_fmt, prompt_user, log_message from dagshub.common.rich_util import get_rich_progress -from dagshub.common.util import lazy_load, multi_urljoin from dagshub.data_engine.annotation import MetadataAnnotations from dagshub.data_engine.annotation.voxel_conversion import ( add_voxel_annotations, @@ -566,7 +564,6 @@ def predict_with_mlflow_model( Args: repo: repository to extract the model from - name: name of the model in the repository's MLflow registry. host: address of the DagsHub instance with the repo to load the model from. 
Set it if the model is hosted on a different DagsHub instance than the datasource. @@ -581,27 +578,7 @@ def predict_with_mlflow_model( if not host: host = self.datasource.source.repoApi.host - prev_uri = mlflow.get_tracking_uri() - mlflow.set_tracking_uri(multi_urljoin(host, f"{repo}.mlflow")) - token = get_token(host=host) - os.environ["MLFLOW_TRACKING_USERNAME"] = UserAPI.get_user_from_token(token, host=host).username - os.environ["MLFLOW_TRACKING_PASSWORD"] = token - model_uri = f"models:/{name}/{version}" - - try: - loader_module = mlflow.models.get_model_info(model_uri).flavors["python_function"]["loader_module"] - loader_module_elems = loader_module.split(".") - if loader_module_elems[-1] == "model": - loader_module_elems.pop() - loader_module = ".".join(loader_module_elems) - loader = mlflow.pyfunc if "pyfunc" in loader_module_elems else importlib.import_module(loader_module) - model = loader.load_model(model_uri) - finally: - mlflow.set_tracking_uri(prev_uri) - - if "torch" in loader_module: - model.predict = model.__call__ - + model = get_mlflow_model(repo, name, host=host, version=version) return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field) def get_annotations(self, **kwargs) -> "QueryResult": diff --git a/dagshub/mlflow/__init__.py b/dagshub/mlflow/__init__.py index 75ea415d..db8e4f05 100644 --- a/dagshub/mlflow/__init__.py +++ b/dagshub/mlflow/__init__.py @@ -1,3 +1,4 @@ from .patch import patch_mlflow, unpatch_mlflow +from .get_model import get_mlflow_model -__all__ = [patch_mlflow.__name__, unpatch_mlflow.__name__] +__all__ = [patch_mlflow.__name__, unpatch_mlflow.__name__, get_mlflow_model.__name__] diff --git a/dagshub/mlflow/get_model.py b/dagshub/mlflow/get_model.py new file mode 100644 index 00000000..fc820f83 --- /dev/null +++ b/dagshub/mlflow/get_model.py @@ -0,0 +1,48 @@ +from dagshub.common.util import lazy_load, multi_urljoin +from dagshub.common.api import UserAPI +from dagshub.auth import 
get_token + +from typing import TYPE_CHECKING +import importlib +import os + +if TYPE_CHECKING: + import mlflow +else: + mlflow = lazy_load("mlflow") + + +def get_mlflow_model(repo, name, host, version): + """ + Get MLflow Model from the specified DagsHub repository, \ + patched to forward the primary prediction function to `predict`. + + Args: + repo: repository to extract the model from + name: name of the model in the repository's MLflow registry. + host: address of the DagsHub instance with the repo to load the model from. + Set it if the model is hosted on a different DagsHub instance than the datasource. + version: version of the model in the mlflow registry. + """ + prev_uri = mlflow.get_tracking_uri() + mlflow.set_tracking_uri(multi_urljoin(host, f"{repo}.mlflow")) + token = get_token(host=host) + os.environ["MLFLOW_TRACKING_USERNAME"] = UserAPI.get_user_from_token(token, host=host).username + os.environ["MLFLOW_TRACKING_PASSWORD"] = token + model_uri = f"models:/{name}/{version}" + + try: + loader_module = mlflow.models.get_model_info(model_uri).flavors["python_function"]["loader_module"] + loader_module_elems = loader_module.split(".") + if loader_module_elems[-1] == "model": + loader_module_elems.pop() + loader_module = ".".join(loader_module_elems) + loader = mlflow.pyfunc if "pyfunc" in loader_module_elems else importlib.import_module(loader_module) + model = loader.load_model(model_uri) + finally: + mlflow.set_tracking_uri(prev_uri) + + if "torch" in loader_module: + model.predict = model.__call__ + + return model From 568c12ccd9340db1748c62b826f374778dfe6784 Mon Sep 17 00:00:00 2001 From: Jinen Setpal Date: Wed, 8 Jan 2025 19:37:53 -0400 Subject: [PATCH 2/2] added typehints, defaults, improved docstring --- dagshub/data_engine/model/query_result.py | 2 +- dagshub/mlflow/get_model.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/dagshub/data_engine/model/query_result.py 
b/dagshub/data_engine/model/query_result.py index 6c1cf41f..ed923dd8 100644 --- a/dagshub/data_engine/model/query_result.py +++ b/dagshub/data_engine/model/query_result.py @@ -541,7 +541,7 @@ def predict_with_mlflow_model( repo: str, name: str, host: Optional[str] = None, - version: str = "latest", + version: Optional[str] = "latest", pre_hook: Callable[[List[str]], Any] = identity_func, post_hook: Callable[[Any], Any] = identity_func, batch_size: int = 1, diff --git a/dagshub/mlflow/get_model.py b/dagshub/mlflow/get_model.py index fc820f83..e88ce28c 100644 --- a/dagshub/mlflow/get_model.py +++ b/dagshub/mlflow/get_model.py @@ -1,8 +1,9 @@ from dagshub.common.util import lazy_load, multi_urljoin from dagshub.common.api import UserAPI from dagshub.auth import get_token +from dagshub.common import config -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import importlib import os @@ -12,10 +13,15 @@ mlflow = lazy_load("mlflow") -def get_mlflow_model(repo, name, host, version): +def get_mlflow_model(repo: str, name: str, version: Optional[str] = "latest", host: Optional[str] = None): """ - Get MLflow Model from the specified DagsHub repository, \ - patched to forward the primary prediction function to `predict`. + Load an MLflow registered model from the specified DagsHub repository. + Call `model.predict()` on the loaded model to infer on your data. + + Example:: + + model = dagshub.mlflow.get_mlflow_model('jinensetpal/COCO_1K', 'yolov8-seg', 'latest', 'https://dagshub.com') + model.predict("https://www.dagshub.com/jinensetpal/COCO_1K/raw/68a54c3bdc84582af3e42fb6f03507146f176378/data/images/train/000000000009.jpg") Args: repo: repository to extract the model from @@ -24,8 +30,11 @@ def get_mlflow_model(repo, name, host, version): Set it if the model is hosted on a different DagsHub instance than the datasource. version: version of the model in the mlflow registry. 
""" + host = host or config.host + prev_uri = mlflow.get_tracking_uri() mlflow.set_tracking_uri(multi_urljoin(host, f"{repo}.mlflow")) + token = get_token(host=host) os.environ["MLFLOW_TRACKING_USERNAME"] = UserAPI.get_user_from_token(token, host=host).username os.environ["MLFLOW_TRACKING_PASSWORD"] = token