Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions dagshub/data_engine/model/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import json
import os
import os.path
import importlib

import dacite
import dagshub_annotation_converter.converters.yolo
Expand All @@ -19,16 +18,15 @@
from dagshub_annotation_converter.formats.yolo.categories import Categories
from dagshub_annotation_converter.formats.yolo.common import ir_mapping
from dagshub_annotation_converter.ir.image import IRImageAnnotationBase
from dagshub.mlflow import get_mlflow_model
from pydantic import ValidationError

from dagshub.auth import get_token
from dagshub.common import config
from dagshub.common.util import lazy_load
from dagshub.common.analytics import send_analytics_event
from dagshub.common.api import UserAPI
from dagshub.common.download import download_files
from dagshub.common.helpers import sizeof_fmt, prompt_user, log_message
from dagshub.common.rich_util import get_rich_progress
from dagshub.common.util import lazy_load, multi_urljoin
from dagshub.data_engine.annotation import MetadataAnnotations
from dagshub.data_engine.annotation.voxel_conversion import (
add_voxel_annotations,
Expand Down Expand Up @@ -543,7 +541,7 @@ def predict_with_mlflow_model(
repo: str,
name: str,
host: Optional[str] = None,
version: str = "latest",
version: Optional[str] = "latest",
pre_hook: Callable[[List[str]], Any] = identity_func,
post_hook: Callable[[Any], Any] = identity_func,
batch_size: int = 1,
Expand All @@ -566,7 +564,6 @@ def predict_with_mlflow_model(

Args:
repo: repository to extract the model from

name: name of the model in the repository's MLflow registry.
host: address of the DagsHub instance with the repo to load the model from.
Set it if the model is hosted on a different DagsHub instance than the datasource.
Expand All @@ -581,27 +578,7 @@ def predict_with_mlflow_model(
if not host:
host = self.datasource.source.repoApi.host

prev_uri = mlflow.get_tracking_uri()
mlflow.set_tracking_uri(multi_urljoin(host, f"{repo}.mlflow"))
token = get_token(host=host)
os.environ["MLFLOW_TRACKING_USERNAME"] = UserAPI.get_user_from_token(token, host=host).username
os.environ["MLFLOW_TRACKING_PASSWORD"] = token
model_uri = f"models:/{name}/{version}"

try:
loader_module = mlflow.models.get_model_info(model_uri).flavors["python_function"]["loader_module"]
loader_module_elems = loader_module.split(".")
if loader_module_elems[-1] == "model":
loader_module_elems.pop()
loader_module = ".".join(loader_module_elems)
loader = mlflow.pyfunc if "pyfunc" in loader_module_elems else importlib.import_module(loader_module)
model = loader.load_model(model_uri)
finally:
mlflow.set_tracking_uri(prev_uri)

if "torch" in loader_module:
model.predict = model.__call__

# Pass version/host by keyword: get_mlflow_model's positional order is
# (repo, name, version, host), so passing them positionally as (host, version) swaps them.
model = get_mlflow_model(repo, name, version=version, host=host)
return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)

def get_annotations(self, **kwargs) -> "QueryResult":
Expand Down
3 changes: 2 additions & 1 deletion dagshub/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .patch import patch_mlflow, unpatch_mlflow
from .get_model import get_mlflow_model

__all__ = [patch_mlflow.__name__, unpatch_mlflow.__name__]
__all__ = [patch_mlflow.__name__, unpatch_mlflow.__name__, get_mlflow_model.__name__]
57 changes: 57 additions & 0 deletions dagshub/mlflow/get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dagshub.common.util import lazy_load, multi_urljoin
from dagshub.common.api import UserAPI
from dagshub.auth import get_token
from dagshub.common import config

from typing import TYPE_CHECKING, Optional
import importlib
import os

if TYPE_CHECKING:
import mlflow
else:
mlflow = lazy_load("mlflow")


def get_mlflow_model(repo: str, name: str, version: Optional[str] = "latest", host: Optional[str] = None):
    """
    Load an MLflow registered model from the specified DagsHub repository.
    Call `model.predict(<data>)` on the loaded model to infer on your data.

    The MLflow tracking URI and the ``MLFLOW_TRACKING_USERNAME`` /
    ``MLFLOW_TRACKING_PASSWORD`` environment variables are pointed at the
    repository only while the model is being loaded; their previous values are
    restored before returning, whether or not loading succeeded.

    Example::

        model = dagshub.mlflow.get_mlflow_model('jinensetpal/COCO_1K', 'yolov8-seg', 'latest', 'https://dagshub.com')
        model.predict("https://www.dagshub.com/jinensetpal/COCO_1K/raw/68a54c3bdc84582af3e42fb6f03507146f176378/data/images/train/000000000009.jpg")

    Args:
        repo: repository to extract the model from
        name: name of the model in the repository's MLflow registry.
        version: version of the model in the mlflow registry.
            ``None`` is treated the same as ``"latest"``.
        host: address of the DagsHub instance with the repo to load the model from.
            Defaults to the host from the DagsHub config.

    Returns:
        The model returned by the flavor's ``load_model``. If the loader module
        name contains ``"torch"``, ``model.predict`` is aliased to ``model.__call__``.
    """
    host = host or config.host
    # Without this, a None version would be formatted into the URI as the literal "None".
    if version is None:
        version = "latest"
    model_uri = f"models:/{name}/{version}"

    # Snapshot every piece of global state touched below so it can be put back.
    credential_vars = ("MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD")
    prev_env = {var: os.environ.get(var) for var in credential_vars}
    prev_uri = mlflow.get_tracking_uri()

    try:
        # Token lookup happens inside the try so that a failure here
        # cannot leave the tracking URI or credentials half-switched.
        token = get_token(host=host)
        os.environ["MLFLOW_TRACKING_USERNAME"] = UserAPI.get_user_from_token(token, host=host).username
        os.environ["MLFLOW_TRACKING_PASSWORD"] = token
        mlflow.set_tracking_uri(multi_urljoin(host, f"{repo}.mlflow"))

        loader_module = mlflow.models.get_model_info(model_uri).flavors["python_function"]["loader_module"]
        # Drop a trailing ".model" component so the remaining path names the
        # module that exposes `load_model`.
        loader_module_elems = loader_module.split(".")
        if loader_module_elems[-1] == "model":
            loader_module_elems.pop()
        loader_module = ".".join(loader_module_elems)
        loader = mlflow.pyfunc if "pyfunc" in loader_module_elems else importlib.import_module(loader_module)
        model = loader.load_model(model_uri)
    finally:
        mlflow.set_tracking_uri(prev_uri)
        for var, previous in prev_env.items():
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous

    # Give torch-loaded models the same `predict` entry point as other flavors.
    if "torch" in loader_module:
        model.predict = model.__call__

    return model
Loading