From e785c94c1f180386ae2bd7aaf0eee5e6f22e2c7a Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 14 Aug 2024 22:19:56 -0700 Subject: [PATCH 01/62] fix Signed-off-by: dbczumar --- mlflow/entities/model.py | 142 ++++++++++++++++++++++++++++++++ mlflow/entities/model_param.py | 38 +++++++++ mlflow/entities/model_status.py | 6 ++ mlflow/entities/model_tag.py | 25 ++++++ 4 files changed, 211 insertions(+) create mode 100644 mlflow/entities/model.py create mode 100644 mlflow/entities/model_param.py create mode 100644 mlflow/entities/model_status.py create mode 100644 mlflow/entities/model_tag.py diff --git a/mlflow/entities/model.py b/mlflow/entities/model.py new file mode 100644 index 0000000000000..a2849b214bea4 --- /dev/null +++ b/mlflow/entities/model.py @@ -0,0 +1,142 @@ +from typing import Dict, List, Optional + +from mlflow.entities._mlflow_object import _MlflowObject +from mlflow.entities.model_param import ModelParam +from mlflow.entities.model_status import ModelStatus +from mlflow.entities.model_tag import ModelTag + + +class Model(_MlflowObject): + """ + MLflow entity representing a Model. 
+ """ + + def __init__( + self, + model_id: str, + name: str, + creation_timestamp: int, + last_updated_timestamp: int, + source: Optional[str] = None, + run_id: Optional[str] = None, + status: ModelStatus = ModelStatus.READY, + status_message: Optional[str] = None, + tags: Optional[List[ModelTag]] = None, + params: Optional[ModelParam] = None, + ): + super().__init__() + self._model_id: str = model_id + self._name: str = name + self._creation_time: int = creation_timestamp + self._last_updated_timestamp: int = last_updated_timestamp + self._source: Optional[str] = source + self._run_id: Optional[str] = run_id + self._status: ModelStatus = status + self._status_message: Optional[str] = status_message + self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} + self._params: Optional[ModelParam] = params + + @property + def model_id(self) -> str: + """String. Unique ID for the Model.""" + return self._name + + @model_id.setter + def model_id(self, new_model_id: str): + self._model_id = new_model_id + + @property + def name(self) -> str: + """String. Name for the Model.""" + return self._name + + @name.setter + def name(self, new_name: str): + self._name = new_name + + @property + def version(self) -> str: + """version""" + return self._version + + @property + def creation_timestamp(self) -> int: + """Integer. Model version creation timestamp (milliseconds since the Unix epoch).""" + return self._creation_time + + @property + def last_updated_timestamp(self) -> int: + """Integer. Timestamp of last update for this model version (milliseconds since the Unix + epoch). + """ + return self._last_updated_timestamp + + @last_updated_timestamp.setter + def last_updated_timestamp(self, updated_timestamp: int): + self._last_updated_timestamp = updated_timestamp + + @property + def description(self) -> str: + """String. 
Description""" + return self._description + + @description.setter + def description(self, description: str): + self._description = description + + @property + def user_id(self) -> str: + """String. User ID that created this model version.""" + return self._user_id + + @property + def current_stage(self) -> str: + """String. Current stage of this model version.""" + return self._current_stage + + @current_stage.setter + def current_stage(self, stage: str): + self._current_stage = stage + + @property + def source(self) -> Optional[str]: + """String. Source path for the model.""" + return self._source + + @property + def run_id(self) -> Optional[str]: + """String. MLflow run ID that generated this model.""" + return self._run_id + + @property + def run_link(self) -> str: + """String. MLflow run link referring to the exact run that generated this model version.""" + return self._run_link + + @property + def status(self) -> ModelStatus: + """String. Current status of this model.""" + return self._status + + @property + def status_message(self) -> Optional[str]: + """String. 
Descriptive message for error status conditions.""" + return self._status_message + + @property + def tags(self) -> Dict[str, str]: + """Dictionary of tag key (string) -> tag value for the current model version.""" + return self._tags + + @property + def params(self) -> Optional[ModelParam]: + """Model parameters.""" + return self._params + + @classmethod + def _properties(cls) -> List[str]: + # aggregate with base class properties since cls.__dict__ does not do it automatically + return sorted(cls._get_properties_helper()) + + def _add_tag(self, tag): + self._tags[tag.key] = tag.value diff --git a/mlflow/entities/model_param.py b/mlflow/entities/model_param.py new file mode 100644 index 0000000000000..b5e4cf7fe8c65 --- /dev/null +++ b/mlflow/entities/model_param.py @@ -0,0 +1,38 @@ +import sys + +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelParam(_MlflowObject): + """ + MLflow entity representing a parameter of a Model. + """ + + def __init__(self, key, value): + if "pyspark.ml" in sys.modules: + import pyspark.ml.param + + if isinstance(key, pyspark.ml.param.Param): + key = key.name + value = str(value) + self._key = key + self._value = value + + @property + def key(self): + """String key corresponding to the parameter name.""" + return self._key + + @property + def value(self): + """String value of the parameter.""" + return self._value + + def __eq__(self, __o): + if isinstance(__o, self.__class__): + return self._key == __o._key + + return False + + def __hash__(self): + return hash(self._key) diff --git a/mlflow/entities/model_status.py b/mlflow/entities/model_status.py new file mode 100644 index 0000000000000..e9222ba8e4bc6 --- /dev/null +++ b/mlflow/entities/model_status.py @@ -0,0 +1,6 @@ +class ModelStatus: + """Enum for status of an :py:class:`mlflow.entities.Model`.""" + + PENDING = "PENDING" + READY = "READY" + FAILED = "FAILED" diff --git a/mlflow/entities/model_tag.py b/mlflow/entities/model_tag.py new file mode 100644 
index 0000000000000..0774ce27759b1 --- /dev/null +++ b/mlflow/entities/model_tag.py @@ -0,0 +1,25 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelTag(_MlflowObject): + """Tag object associated with a Model.""" + + def __init__(self, key, value): + self._key = key + self._value = value + + def __eq__(self, other): + if type(other) is type(self): + # TODO deep equality here? + return self.__dict__ == other.__dict__ + return False + + @property + def key(self): + """String name of the tag.""" + return self._key + + @property + def value(self): + """String value of the tag.""" + return self._value From 04b785b7bfe6b195e3e7e0fac4c261cacbc9b03d Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 00:24:41 -0700 Subject: [PATCH 02/62] fix Signed-off-by: dbczumar --- mlflow/entities/__init__.py | 10 ++ mlflow/entities/model.py | 76 ++++------ mlflow/entities/model_status.py | 5 +- mlflow/protos/internal.proto | 2 + mlflow/store/tracking/file_store.py | 207 +++++++++++++++++++++++++++- 5 files changed, 244 insertions(+), 56 deletions(-) diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index 483d39835e95b..84c420cdf684a 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -12,6 +12,11 @@ from mlflow.entities.input_tag import InputTag from mlflow.entities.lifecycle_stage import LifecycleStage from mlflow.entities.metric import Metric +from mlflow.entities.model import Model +from mlflow.entities.model_input import ModelInput +from mlflow.entities.model_param import ModelParam +from mlflow.entities.model_status import ModelStatus +from mlflow.entities.model_tag import ModelTag from mlflow.entities.param import Param from mlflow.entities.run import Run from mlflow.entities.run_data import RunData @@ -57,4 +62,9 @@ "TraceInfo", "SpanStatusCode", "_DatasetSummary", + "Model", + "ModelInput", + "ModelStatus", + "ModelTag", + "ModelParam", ] diff --git a/mlflow/entities/model.py 
b/mlflow/entities/model.py index a2849b214bea4..28e236fd41f4e 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/model.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from mlflow.entities._mlflow_object import _MlflowObject from mlflow.entities.model_param import ModelParam @@ -13,11 +13,11 @@ class Model(_MlflowObject): def __init__( self, + experiment_id: str, # New field added model_id: str, name: str, creation_timestamp: int, last_updated_timestamp: int, - source: Optional[str] = None, run_id: Optional[str] = None, status: ModelStatus = ModelStatus.READY, status_message: Optional[str] = None, @@ -25,21 +25,30 @@ def __init__( params: Optional[ModelParam] = None, ): super().__init__() + self._experiment_id: str = experiment_id # New field initialized self._model_id: str = model_id self._name: str = name self._creation_time: int = creation_timestamp self._last_updated_timestamp: int = last_updated_timestamp - self._source: Optional[str] = source self._run_id: Optional[str] = run_id self._status: ModelStatus = status self._status_message: Optional[str] = status_message self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} self._params: Optional[ModelParam] = params + @property + def experiment_id(self) -> str: + """String. Experiment ID associated with this Model.""" + return self._experiment_id + + @experiment_id.setter + def experiment_id(self, new_experiment_id: str): + self._experiment_id = new_experiment_id + @property def model_id(self) -> str: - """String. Unique ID for the Model.""" - return self._name + """String. Unique ID for this Model.""" + return self._model_id @model_id.setter def model_id(self, new_model_id: str): @@ -47,26 +56,21 @@ def model_id(self, new_model_id: str): @property def name(self) -> str: - """String. Name for the Model.""" + """String. 
Name for this Model.""" return self._name @name.setter def name(self, new_name: str): self._name = new_name - @property - def version(self) -> str: - """version""" - return self._version - @property def creation_timestamp(self) -> int: - """Integer. Model version creation timestamp (milliseconds since the Unix epoch).""" + """Integer. Model creation timestamp (milliseconds since the Unix epoch).""" return self._creation_time @property def last_updated_timestamp(self) -> int: - """Integer. Timestamp of last update for this model version (milliseconds since the Unix + """Integer. Timestamp of last update for this Model (milliseconds since the Unix epoch). """ return self._last_updated_timestamp @@ -75,49 +79,20 @@ def last_updated_timestamp(self) -> int: def last_updated_timestamp(self, updated_timestamp: int): self._last_updated_timestamp = updated_timestamp - @property - def description(self) -> str: - """String. Description""" - return self._description - - @description.setter - def description(self, description: str): - self._description = description - - @property - def user_id(self) -> str: - """String. User ID that created this model version.""" - return self._user_id - - @property - def current_stage(self) -> str: - """String. Current stage of this model version.""" - return self._current_stage - - @current_stage.setter - def current_stage(self, stage: str): - self._current_stage = stage - - @property - def source(self) -> Optional[str]: - """String. Source path for the model.""" - return self._source - @property def run_id(self) -> Optional[str]: """String. MLflow run ID that generated this model.""" return self._run_id - @property - def run_link(self) -> str: - """String. MLflow run link referring to the exact run that generated this model version.""" - return self._run_link - @property def status(self) -> ModelStatus: - """String. Current status of this model.""" + """String. 
Current status of this Model.""" return self._status + @status.setter + def status(self, updated_status: str): + self._status = updated_status + @property def status_message(self) -> Optional[str]: """String. Descriptive message for error status conditions.""" @@ -125,7 +100,7 @@ def status_message(self) -> Optional[str]: @property def tags(self) -> Dict[str, str]: - """Dictionary of tag key (string) -> tag value for the current model version.""" + """Dictionary of tag key (string) -> tag value for this Model.""" return self._tags @property @@ -140,3 +115,8 @@ def _properties(cls) -> List[str]: def _add_tag(self, tag): self._tags[tag.key] = tag.value + + def to_dictionary(self) -> Dict[str, Any]: + model_dict = dict(self) + model_dict["status"] = str(self.status) + return model_dict diff --git a/mlflow/entities/model_status.py b/mlflow/entities/model_status.py index e9222ba8e4bc6..495eb0638022a 100644 --- a/mlflow/entities/model_status.py +++ b/mlflow/entities/model_status.py @@ -1,4 +1,7 @@ -class ModelStatus: +from enum import Enum + + +class ModelStatus(str, Enum): """Enum for status of an :py:class:`mlflow.entities.Model`.""" PENDING = "PENDING" diff --git a/mlflow/protos/internal.proto b/mlflow/protos/internal.proto index 614a1916c1415..fcffd056b3957 100644 --- a/mlflow/protos/internal.proto +++ b/mlflow/protos/internal.proto @@ -20,4 +20,6 @@ enum InputVertexType { RUN = 1; DATASET = 2; + + MODEL = 3; } diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index c2bdd63a09d1a..96b16b37e4180 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -6,7 +6,7 @@ import time import uuid from dataclasses import dataclass -from typing import Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from mlflow.entities import ( Dataset, @@ -15,6 +15,11 @@ ExperimentTag, InputTag, Metric, + Model, + ModelInput, + ModelParam, + ModelStatus, + ModelTag, 
Param, Run, RunData, @@ -170,6 +175,7 @@ class FileStore(AbstractStore): DATASETS_FOLDER_NAME, TRACES_FOLDER_NAME, ] + MODELS_FOLDER_NAME = "models" def __init__(self, root_directory=None, artifact_root_uri=None): """ @@ -1112,14 +1118,21 @@ def record_logged_model(self, run_id, mlflow_model): except Exception as e: raise MlflowException(e, INTERNAL_ERROR) - def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None): + def log_inputs( + self, + run_id: str, + datasets: Optional[List[DatasetInput]] = None, + models: Optional[List[ModelInput]] = None, + ): """ - Log inputs, such as datasets, to the specified run. + Log inputs, such as datasets and models, to the specified run. Args: run_id: String id for the run datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log as inputs to the run. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log + as inputs to the run. Returns: None. @@ -1128,13 +1141,13 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) run_info = self._get_run_info(run_id) check_run_is_active(run_info) - if datasets is None: + if datasets is None and models is None: return experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) run_dir = self._get_run_dir(run_info.experiment_id, run_id) - for dataset_input in datasets: + for dataset_input in datasets or []: dataset = dataset_input.dataset dataset_id = FileStore._get_dataset_id( dataset_name=dataset.name, dataset_digest=dataset.digest @@ -1144,7 +1157,7 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) os.makedirs(dataset_dir, exist_ok=True) write_yaml(dataset_dir, FileStore.META_DATA_FILE_NAME, dict(dataset)) - input_id = FileStore._get_input_id(dataset_id=dataset_id, run_id=run_id) + input_id = FileStore._get_dataset_input_id(dataset_id=dataset_id, run_id=run_id) input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id) if not 
os.path.exists(input_dir): os.makedirs(input_dir, exist_ok=True) @@ -1157,6 +1170,21 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) ) fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME) + for model_input in models or []: + model_id = model_input.model_id + input_id = FileStore._get_model_input_id(model_id=model_id, run_id=run_id) + input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id) + if not os.path.exists(input_dir): + os.makedirs(input_dir, exist_ok=True) + fs_input = FileStore._FileStoreInput( + source_type=InputVertexType.MODEL, + source_id=model_id, + destination_type=InputVertexType.RUN, + destination_id=run_id, + tags={}, + ) + fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME) + @staticmethod def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str: md5 = insecure_hash.md5(dataset_name.encode("utf-8")) @@ -1164,11 +1192,17 @@ def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str: return md5.hexdigest() @staticmethod - def _get_input_id(dataset_id: str, run_id: str) -> str: + def _get_dataset_input_id(dataset_id: str, run_id: str) -> str: md5 = insecure_hash.md5(dataset_id.encode("utf-8")) md5.update(run_id.encode("utf-8")) return md5.hexdigest() + @staticmethod + def _get_model_input_id(model_id: str, run_id: str) -> str: + md5 = insecure_hash.md5(model_id.encode("utf-8")) + md5.update(run_id.encode("utf-8")) + return md5.hexdigest() + class _FileStoreInput(NamedTuple): source_type: int source_id: str @@ -1686,3 +1720,162 @@ def _list_trace_infos(self, experiment_id): exc_info=_logger.isEnabledFor(logging.DEBUG), ) return trace_infos + + def create_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[List[ModelTag]] = None, + params: Optional[List[ModelParam]] = None, + ) -> Model: + """ + Create a new model. + + Args: + experiment_id: ID of the Experiment where the model is being created. 
+ name: Name of the model. + run_id: Run ID where the model is being created from. + tags: Key-value tags for the model. + params: Key-value params for the model. + + Returns: + The model version. + """ + experiment_id = FileStore.DEFAULT_EXPERIMENT_ID if experiment_id is None else experiment_id + experiment = self.get_experiment(experiment_id) + if experiment is None: + raise MlflowException( + "Could not create model under experiment with ID %s - no such experiment " + "exists." % experiment_id, + databricks_pb2.RESOURCE_DOES_NOT_EXIST, + ) + if experiment.lifecycle_stage != LifecycleStage.ACTIVE: + raise MlflowException( + f"Could not create model under non-active experiment with ID {experiment_id}.", + databricks_pb2.INVALID_STATE, + ) + + model_id = str(uuid.uuid4()) + artifact_uri = self._get_model_artifact_dir(experiment_id, model_id) + creation_timestamp = int(time.time() * 1000) + model = Model( + experiment_id=experiment_id, + model_id=model_id, + name=name, + creation_timestamp=creation_timestamp, + last_updated_timestamp=creation_timestamp, + run_id=run_id, + status=ModelStatus.PENDING, + tags=tags, + params=params, + ) + + # Persist model metadata and create directories for logging metrics, tags + model_dir = self._get_model_dir(experiment_id, model_id) + mkdir(model_dir) + model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model, artifact_uri) + write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict) + for tag in tags or []: + self.set_model_tag(model_id=model_id, tag=tag) + + return self.get_model(model_id=model_id) + + def finalize_model(self, model_id: str, status: ModelStatus) -> Model: + """ + Finalize a model by updating its status. + + Args: + model_id: ID of the model to finalize. + status: Final status to set on the model. + + Returns: + The updated model. + """ + if status != ModelStatus.READY: + raise MlflowException( + f"Invalid model status: {status}. 
Expected statuses: [{ModelStatus.READY}]", + databricks_pb2.INVALID_PARAMETER_VALUE, + ) + model_dict = self._get_model_dict(model_id) + model = Model.from_dictionary(model_dict) + model.status = status + model.last_updated_timestamp = int(time.time() * 1000) + model_dir = self._get_model_dir(model.experiment_id, model.model_id) + model_info_dict = self._make_persisted_model_dict(model, model_dict["artifact_location"]) + write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict, overwrite=True) + return self.get_model(model_id) + + def set_model_tag(self, model_id: str, tag: ModelTag): + _validate_tag_name(tag.key) + model = self.get_model(model_id) + tag_path = os.path.join( + self._get_model_dir(model.experiment_id, model.model_id), + FileStore.TAGS_FOLDER_NAME, + tag.key, + ) + make_containing_dirs(tag_path) + # Don't add trailing newline + write_to(tag_path, self._writeable_value(tag.value)) + return + + def get_model(self, model_id: str) -> Model: + return Model.from_dictionary(self._get_model_dict(model_id)) + + def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: + return append_to_uri_path( + self.get_experiment(experiment_id).artifact_location, + model_id, + FileStore.ARTIFACTS_FOLDER_NAME, + ) + + def _make_persisted_model_dict(self, model: Model, artifact_location) -> Dict[str, Any]: + model_dict = model.to_dictionary() + model_dict["artifact_location"] = artifact_location + model_dict.pop("tags", None) + model_dict["params"] = {param.key: param.value for param in model.params or []} + return model_dict + + def _get_model_dict(self, model_id: str) -> Dict[str, Any]: + exp_id, model_dir = self._find_model_root(model_id) + if model_dir is None: + raise MlflowException( + f"Model '{model_id}' not found", databricks_pb2.RESOURCE_DOES_NOT_EXIST + ) + model_dict: Dict[str, Any] = self._get_model_info_from_dir(model_dir) + if model_dict["experiment_id"] != exp_id: + raise MlflowException( + f"Model '{model_id}' metadata is in 
invalid state.", databricks_pb2.INVALID_STATE + ) + model_dict["tags"] = self._get_all_model_tags(model_dir) + return model_dict + + def _get_model_dir(self, experiment_id: str, model_id: str) -> str: + if not self._has_experiment(experiment_id): + return None + return os.path.join( + self._get_experiment_path(experiment_id, assert_exists=True), + FileStore.MODELS_FOLDER_NAME, + model_id, + ) + + def _find_model_root(self, model_id): + self._check_root_dir() + all_experiments = self._get_active_experiments(True) + self._get_deleted_experiments(True) + for experiment_dir in all_experiments: + models_dir_path = os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME) + models = find(models_dir_path, model_id, full_path=True) + if len(models) == 0: + continue + return os.path.basename(os.path.dirname(os.path.abspath(models_dir_path))), models[0] + return None, None + + def _get_model_info_from_dir(self, model_dir: str) -> Dict[str, Any]: + return FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME) + + def _get_all_model_tags(self, model_dir: str) -> List[ModelTag]: + parent_path, tag_files = self._get_resource_files(model_dir, FileStore.TAGS_FOLDER_NAME) + tags = [] + for tag_file in tag_files: + tags.append(self._get_tag_from_file(parent_path, tag_file)) + return tags From 7f8291e07c9f282bc218d2cd972f07c74aaf91d0 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 00:55:59 -0700 Subject: [PATCH 03/62] progress Signed-off-by: dbczumar --- mlflow/entities/model_input.py | 18 +++++++++ mlflow/entities/run_inputs.py | 20 +++++++++- mlflow/entities/run_outputs.py | 26 +++++++++++++ .../org/mlflow/internal/proto/Internal.java | 15 ++++++-- mlflow/protos/internal_pb2.py | 9 +++-- mlflow/store/tracking/file_store.py | 38 +++++++++++++++---- 6 files changed, 110 insertions(+), 16 deletions(-) create mode 100644 mlflow/entities/model_input.py create mode 100644 mlflow/entities/run_outputs.py diff --git a/mlflow/entities/model_input.py 
b/mlflow/entities/model_input.py new file mode 100644 index 0000000000000..456c6db70ae5f --- /dev/null +++ b/mlflow/entities/model_input.py @@ -0,0 +1,18 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelInput(_MlflowObject): + """ModelInput object associated with a Run.""" + + def __init__(self, model_id: str): + self._model_id = model_id + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return False + + @property + def model_id(self) -> str: + """Model ID.""" + return self._model_id diff --git a/mlflow/entities/run_inputs.py b/mlflow/entities/run_inputs.py index d28f026c71bc3..e5b8f1ccc7c7f 100644 --- a/mlflow/entities/run_inputs.py +++ b/mlflow/entities/run_inputs.py @@ -2,14 +2,16 @@ from mlflow.entities._mlflow_object import _MlflowObject from mlflow.entities.dataset_input import DatasetInput +from mlflow.entities.model_input import ModelInput from mlflow.protos.service_pb2 import RunInputs as ProtoRunInputs class RunInputs(_MlflowObject): """RunInputs object.""" - def __init__(self, dataset_inputs: List[DatasetInput]) -> None: + def __init__(self, dataset_inputs: List[DatasetInput], model_inputs: List[ModelInput]) -> None: self._dataset_inputs = dataset_inputs + self._model_inputs = model_inputs def __eq__(self, other: _MlflowObject) -> bool: if type(other) is type(self): @@ -21,16 +23,26 @@ def dataset_inputs(self) -> List[DatasetInput]: """Array of dataset inputs.""" return self._dataset_inputs + @property + def model_inputs(self) -> List[ModelInput]: + """Array of model inputs.""" + return self._model_inputs + def to_proto(self): run_inputs = ProtoRunInputs() run_inputs.dataset_inputs.extend( [dataset_input.to_proto() for dataset_input in self.dataset_inputs] ) + # TODO: Support proto conversion for model inputs + # run_inputs.model_inputs.extend( + # [model_input.to_proto() for model_input in self.model_inputs] + # ) return run_inputs def 
to_dictionary(self) -> Dict[Any, Any]: return { "dataset_inputs": self.dataset_inputs, + "model_inputs": self.model_inputs, } @classmethod @@ -38,4 +50,8 @@ def from_proto(cls, proto): dataset_inputs = [ DatasetInput.from_proto(dataset_input) for dataset_input in proto.dataset_inputs ] - return cls(dataset_inputs) + # TODO: Support proto conversion for model inputs + # model_inputs = [ + # ModelInput.from_proto(model_input) for model_input in proto.model_inputs + # ] + return cls(dataset_inputs, []) diff --git a/mlflow/entities/run_outputs.py b/mlflow/entities/run_outputs.py new file mode 100644 index 0000000000000..3d4b8a5f83b77 --- /dev/null +++ b/mlflow/entities/run_outputs.py @@ -0,0 +1,26 @@ +from typing import Any, Dict, List + +from mlflow.entities._mlflow_object import _MlflowObject +from mlflow.entities.model_output import ModelOutput + + +class RunOutputs(_MlflowObject): + """RunOutputs object.""" + + def __init__(self, model_outputs: List[ModelOutput]) -> None: + self._model_outputs = model_outputs + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return False + + @property + def model_outputs(self) -> List[ModelOutput]: + """Array of model outputs.""" + return self._model_outputs + + def to_dictionary(self) -> Dict[Any, Any]: + return { + "model_outputs": self.model_outputs, + } diff --git a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java index 487c8f9a2864e..d3accc8bbb95a 100644 --- a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java +++ b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java @@ -32,6 +32,10 @@ public enum InputVertexType * DATASET = 2; */ DATASET(2), + /** + * MODEL = 3; + */ + MODEL(3), ; /** @@ -42,6 +46,10 @@ public enum InputVertexType * DATASET = 2; */ public static final int DATASET_VALUE = 2; + /** + * MODEL = 3; 
+ */ + public static final int MODEL_VALUE = 3; public final int getNumber() { @@ -66,6 +74,7 @@ public static InputVertexType forNumber(int value) { switch (value) { case 1: return RUN; case 2: return DATASET; + case 3: return MODEL; default: return null; } } @@ -125,9 +134,9 @@ private InputVertexType(int value) { static { java.lang.String[] descriptorData = { "\n\016internal.proto\022\017mlflow.internal\032\025scala" + - "pb/scalapb.proto*\'\n\017InputVertexType\022\007\n\003R" + - "UN\020\001\022\013\n\007DATASET\020\002B#\n\031org.mlflow.internal" + - ".proto\220\001\001\342?\002\020\001" + "pb/scalapb.proto*2\n\017InputVertexType\022\007\n\003R" + + "UN\020\001\022\013\n\007DATASET\020\002\022\t\n\005MODEL\020\003B#\n\031org.mlfl" + + "ow.internal.proto\220\001\001\342?\002\020\001" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, diff --git a/mlflow/protos/internal_pb2.py b/mlflow/protos/internal_pb2.py index 7fcf97a79455a..7752bc693568c 100644 --- a/mlflow/protos/internal_pb2.py +++ b/mlflow/protos/internal_pb2.py @@ -19,7 +19,7 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,7 +28,7 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = 
b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _globals['_INPUTVERTEXTYPE']._serialized_start=58 - _globals['_INPUTVERTEXTYPE']._serialized_end=97 + _globals['_INPUTVERTEXTYPE']._serialized_end=108 # @@protoc_insertion_point(module_scope) else: @@ -50,12 +50,13 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _INPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['InputVertexType'] InputVertexType = enum_type_wrapper.EnumTypeWrapper(_INPUTVERTEXTYPE) RUN = 1 DATASET = 2 + MODEL = 3 if _descriptor._USE_C_DESCRIPTORS == False: @@ -63,6 +64,6 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _INPUTVERTEXTYPE._serialized_start=58 - _INPUTVERTEXTYPE._serialized_end=97 + _INPUTVERTEXTYPE._serialized_end=108 # @@protoc_insertion_point(module_scope) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 96b16b37e4180..5a63e73751345 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -1235,9 +1235,19 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id) inputs_parent_path = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME) experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) - datasets_parent_path = 
os.path.join(experiment_dir, FileStore.DATASETS_FOLDER_NAME) - if not os.path.exists(inputs_parent_path) or not os.path.exists(datasets_parent_path): - return RunInputs(dataset_inputs=[]) + if not os.path.exists(inputs_parent_path): + return RunInputs(dataset_inputs=[], model_inputs=[]) + + dataset_inputs = self._get_dataset_inputs(run_info, inputs_parent_path, experiment_dir) + model_inputs = self._get_model_inputs(run_info, inputs_parent_path, experiment_dir) + return RunInputs(dataset_inputs=dataset_inputs, model_inputs=model_inputs) + + def _get_dataset_inputs( + self, run_info: RunInfo, inputs_parent_path: str, experiment_dir_path: str + ) -> List[DatasetInput]: + datasets_parent_path = os.path.join(experiment_dir_path, FileStore.DATASETS_FOLDER_NAME) + if not os.path.exists(datasets_parent_path): + return [] dataset_dirs = os.listdir(datasets_parent_path) dataset_inputs = [] @@ -1247,9 +1257,6 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: input_dir_full_path, FileStore.META_DATA_FILE_NAME ) if fs_input.source_type != InputVertexType.DATASET: - logging.warning( - f"Encountered invalid run input source type '{fs_input.source_type}'. Skipping." 
- ) continue matching_dataset_dirs = [d for d in dataset_dirs if d == fs_input.source_id] @@ -1272,7 +1279,24 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: ) dataset_inputs.append(dataset_input) - return RunInputs(dataset_inputs=dataset_inputs) + return dataset_inputs + + def _get_model_inputs( + self, inputs_parent_path: str, experiment_dir_path: str + ) -> List[ModelInput]: + model_inputs = [] + for input_dir in os.listdir(inputs_parent_path): + input_dir_full_path = os.path.join(inputs_parent_path, input_dir) + fs_input = FileStore._FileStoreInput.from_yaml( + input_dir_full_path, FileStore.META_DATA_FILE_NAME + ) + if fs_input.source_type != InputVertexType.MODEL: + continue + + model_input = ModelInput(model_id=fs_input.source_id) + model_inputs.append(model_input) + + return model_inputs def _search_datasets(self, experiment_ids) -> List[_DatasetSummary]: """ From ad26e8fd6502874efdd40547cd25d16b932ec12a Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 01:17:44 -0700 Subject: [PATCH 04/62] progress Signed-off-by: dbczumar --- mlflow/entities/__init__.py | 4 + mlflow/entities/model_output.py | 18 +++ mlflow/entities/run.py | 28 ++++- .../org/mlflow/internal/proto/Internal.java | 106 +++++++++++++++++- mlflow/protos/internal.proto | 8 ++ mlflow/protos/internal_pb2.py | 12 +- mlflow/store/tracking/file_store.py | 102 ++++++++++++++++- 7 files changed, 267 insertions(+), 11 deletions(-) create mode 100644 mlflow/entities/model_output.py diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index 84c420cdf684a..86283eb1c4707 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -14,6 +14,7 @@ from mlflow.entities.metric import Metric from mlflow.entities.model import Model from mlflow.entities.model_input import ModelInput +from mlflow.entities.model_output import ModelOutput from mlflow.entities.model_param import ModelParam from mlflow.entities.model_status import ModelStatus from 
mlflow.entities.model_tag import ModelTag @@ -22,6 +23,7 @@ from mlflow.entities.run_data import RunData from mlflow.entities.run_info import RunInfo from mlflow.entities.run_inputs import RunInputs +from mlflow.entities.run_outputs import RunOutputs from mlflow.entities.run_status import RunStatus from mlflow.entities.run_tag import RunTag from mlflow.entities.source_type import SourceType @@ -51,6 +53,7 @@ "InputTag", "DatasetInput", "RunInputs", + "RunOutputs", "Span", "LiveSpan", "NoOpSpan", @@ -64,6 +67,7 @@ "_DatasetSummary", "Model", "ModelInput", + "ModelOutput", "ModelStatus", "ModelTag", "ModelParam", diff --git a/mlflow/entities/model_output.py b/mlflow/entities/model_output.py new file mode 100644 index 0000000000000..81608c54ca771 --- /dev/null +++ b/mlflow/entities/model_output.py @@ -0,0 +1,18 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelOutput(_MlflowObject): + """ModelOutput object associated with a Run.""" + + def __init__(self, model_id: str): + self._model_id = model_id + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return False + + @property + def model_id(self) -> str: + """Model ID""" + return self._model_id diff --git a/mlflow/entities/run.py b/mlflow/entities/run.py index 0fade6fa8daa2..8bdc68fe8d895 100644 --- a/mlflow/entities/run.py +++ b/mlflow/entities/run.py @@ -4,6 +4,7 @@ from mlflow.entities.run_data import RunData from mlflow.entities.run_info import RunInfo from mlflow.entities.run_inputs import RunInputs +from mlflow.entities.run_outputs import RunOutputs from mlflow.exceptions import MlflowException from mlflow.protos.service_pb2 import Run as ProtoRun @@ -14,13 +15,18 @@ class Run(_MlflowObject): """ def __init__( - self, run_info: RunInfo, run_data: RunData, run_inputs: Optional[RunInputs] = None + self, + run_info: RunInfo, + run_data: RunData, + run_inputs: Optional[RunInputs] = None, + run_outputs: 
Optional[RunOutputs] = None, ) -> None: if run_info is None: raise MlflowException("run_info cannot be None") self._info = run_info self._data = run_data self._inputs = run_inputs + self._outputs = run_outputs @property def info(self) -> RunInfo: @@ -43,12 +49,21 @@ def data(self) -> RunData: @property def inputs(self) -> RunInputs: """ - The run inputs, including dataset inputs + The run inputs, including dataset inputs. :rtype: :py:class:`mlflow.entities.RunInputs` """ return self._inputs + @property + def outputs(self) -> RunOutputs: + """ + The run outputs, including model outputs. + + :rtype: :py:class:`mlflow.entities.RunOutputs` + """ + return self._outputs + def to_proto(self): run = ProtoRun() run.info.MergeFrom(self.info.to_proto()) @@ -56,6 +71,9 @@ def to_proto(self): run.data.MergeFrom(self.data.to_proto()) if self.inputs: run.inputs.MergeFrom(self.inputs.to_proto()) + # TODO: Support proto conversion for RunOutputs + # if self.outputs: + # run.outputs.MergeFrom(self.outputs.to_proto()) return run @classmethod @@ -63,7 +81,9 @@ def from_proto(cls, proto): return cls( RunInfo.from_proto(proto.info), RunData.from_proto(proto.data), - RunInputs.from_proto(proto.inputs), + RunInputs.from_proto(proto.inputs) if proto.inputs else None, + # TODO: Support proto conversion for RunOutputs + # RunOutputs.from_proto(proto.outputs) if proto.outputs else None, ) def to_dictionary(self) -> Dict[Any, Any]: @@ -74,4 +94,6 @@ def to_dictionary(self) -> Dict[Any, Any]: run_dict["data"] = self.data.to_dictionary() if self.inputs: run_dict["inputs"] = self.inputs.to_dictionary() + if self.outputs: + run_dict["outputs"] = self.outputs.to_dictionary() return run_dict diff --git a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java index d3accc8bbb95a..4139fe8cab66d 100644 --- a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java +++ 
b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java @@ -124,6 +124,107 @@ private InputVertexType(int value) { // @@protoc_insertion_point(enum_scope:mlflow.internal.InputVertexType) } + /** + *
+   * Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can
+   * have an output relationship.
+   * 
+ * + * Protobuf enum {@code mlflow.internal.OutputVertexType} + */ + public enum OutputVertexType + implements com.google.protobuf.ProtocolMessageEnum { + /** + * RUN_OUTPUT = 1; + */ + RUN_OUTPUT(1), + /** + * MODEL_OUTPUT = 2; + */ + MODEL_OUTPUT(2), + ; + + /** + * RUN_OUTPUT = 1; + */ + public static final int RUN_OUTPUT_VALUE = 1; + /** + * MODEL_OUTPUT = 2; + */ + public static final int MODEL_OUTPUT_VALUE = 2; + + + public final int getNumber() { + return value; + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static OutputVertexType valueOf(int value) { + return forNumber(value); + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + */ + public static OutputVertexType forNumber(int value) { + switch (value) { + case 1: return RUN_OUTPUT; + case 2: return MODEL_OUTPUT; + default: return null; + } + } + + public static com.google.protobuf.Internal.EnumLiteMap + internalGetValueMap() { + return internalValueMap; + } + private static final com.google.protobuf.Internal.EnumLiteMap< + OutputVertexType> internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public OutputVertexType findValueByNumber(int number) { + return OutputVertexType.forNumber(number); + } + }; + + public final com.google.protobuf.Descriptors.EnumValueDescriptor + getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } + public final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptorForType() { + return getDescriptor(); + } + public static final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptor() { + return org.mlflow.internal.proto.Internal.getDescriptor().getEnumTypes().get(1); + } + + private static final OutputVertexType[] VALUES = 
values(); + + public static OutputVertexType valueOf( + com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException( + "EnumValueDescriptor is not for this type."); + } + return VALUES[desc.getIndex()]; + } + + private final int value; + + private OutputVertexType(int value) { + this.value = value; + } + + // @@protoc_insertion_point(enum_scope:mlflow.internal.OutputVertexType) + } + public static com.google.protobuf.Descriptors.FileDescriptor getDescriptor() { @@ -135,8 +236,9 @@ private InputVertexType(int value) { java.lang.String[] descriptorData = { "\n\016internal.proto\022\017mlflow.internal\032\025scala" + "pb/scalapb.proto*2\n\017InputVertexType\022\007\n\003R" + - "UN\020\001\022\013\n\007DATASET\020\002\022\t\n\005MODEL\020\003B#\n\031org.mlfl" + - "ow.internal.proto\220\001\001\342?\002\020\001" + "UN\020\001\022\013\n\007DATASET\020\002\022\t\n\005MODEL\020\003*4\n\020OutputVe" + + "rtexType\022\016\n\nRUN_OUTPUT\020\001\022\020\n\014MODEL_OUTPUT" + + "\020\002B#\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, diff --git a/mlflow/protos/internal.proto b/mlflow/protos/internal.proto index fcffd056b3957..057e18fea1c9c 100644 --- a/mlflow/protos/internal.proto +++ b/mlflow/protos/internal.proto @@ -23,3 +23,11 @@ enum InputVertexType { MODEL = 3; } + +// Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can +// have an output relationship. 
+enum OutputVertexType { + RUN_OUTPUT = 1; + + MODEL_OUTPUT = 2; +} diff --git a/mlflow/protos/internal_pb2.py b/mlflow/protos/internal_pb2.py index 7752bc693568c..7aa93249e6ff5 100644 --- a/mlflow/protos/internal_pb2.py +++ b/mlflow/protos/internal_pb2.py @@ -19,7 +19,7 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,6 +29,8 @@ _globals['DESCRIPTOR']._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _globals['_INPUTVERTEXTYPE']._serialized_start=58 _globals['_INPUTVERTEXTYPE']._serialized_end=108 + _globals['_OUTPUTVERTEXTYPE']._serialized_start=110 + _globals['_OUTPUTVERTEXTYPE']._serialized_end=162 # @@protoc_insertion_point(module_scope) else: @@ -50,13 +52,17 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _INPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['InputVertexType'] InputVertexType = enum_type_wrapper.EnumTypeWrapper(_INPUTVERTEXTYPE) + _OUTPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['OutputVertexType'] + OutputVertexType = enum_type_wrapper.EnumTypeWrapper(_OUTPUTVERTEXTYPE) RUN = 1 DATASET = 2 MODEL = 3 + RUN_OUTPUT = 1 + MODEL_OUTPUT = 2 if _descriptor._USE_C_DESCRIPTORS == False: @@ -65,5 +71,7 @@ DESCRIPTOR._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _INPUTVERTEXTYPE._serialized_start=58 _INPUTVERTEXTYPE._serialized_end=108 + _OUTPUTVERTEXTYPE._serialized_start=110 + _OUTPUTVERTEXTYPE._serialized_end=162 # @@protoc_insertion_point(module_scope) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 5a63e73751345..97d8610e559a2 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -17,6 +17,7 @@ Metric, Model, ModelInput, + ModelOutput, ModelParam, ModelStatus, ModelTag, @@ -25,6 +26,7 @@ RunData, RunInfo, RunInputs, + RunOutputs, RunStatus, RunTag, SourceType, @@ -43,7 +45,7 @@ INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST, ) -from mlflow.protos.internal_pb2 import InputVertexType +from mlflow.protos.internal_pb2 import InputVertexType, OutputVertexType from mlflow.store.entities.paged_list import PagedList from mlflow.store.model_registry.file_store import FileStore as ModelRegistryFileStore from mlflow.store.tracking import ( @@ -164,6 +166,7 @@ class FileStore(AbstractStore): EXPERIMENT_TAGS_FOLDER_NAME = "tags" DATASETS_FOLDER_NAME = "datasets" 
INPUTS_FOLDER_NAME = "inputs" + OUTPUTS_FOLDER_NAME = "outputs" META_DATA_FILE_NAME = "meta.yaml" DEFAULT_EXPERIMENT_ID = "0" TRACE_INFO_FILE_NAME = "trace_info.yaml" @@ -693,11 +696,12 @@ def _get_run_from_info(self, run_info): params = self._get_all_params(run_info) tags = self._get_all_tags(run_info) inputs: RunInputs = self._get_all_inputs(run_info) + outputs: RunOutputs = self._get_all_outputs(run_info) if not run_info.run_name: run_name = _get_run_name_from_tags(tags) if run_name: run_info._set_run_name(run_name) - return Run(run_info, RunData(metrics, params, tags), inputs) + return Run(run_info, RunData(metrics, params, tags), inputs, outputs) def _get_run_info(self, run_uuid): """ @@ -1185,6 +1189,41 @@ def log_inputs( ) fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME) + def log_outputs(self, run_id, models: Optional[List[ModelOutput]] = None): + """ + Log outputs, such as models, to the specified run. + + Args: + run_id: String id for the run + models: List of :py:class:`mlflow.entities.ModelOutput` instances to log + as outputs of the run. + + Returns: + None. 
+ """ + _validate_run_id(run_id) + run_info = self._get_run_info(run_id) + check_run_is_active(run_info) + + if models is None: + return + + run_dir = self._get_run_dir(run_info.experiment_id, run_id) + + for model_output in models: + model_id = model_output.model_id + output_dir = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME, model_id) + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + fs_output = FileStore._FileStoreOutput( + source_type=OutputVertexType.RUN_OUTPUT, + source_id=model_id, + destination_type=OutputVertexType.MODEL_OUTPUT, + destination_id=run_id, + tags={}, + ) + fs_output.write_yaml(output_dir, FileStore.META_DATA_FILE_NAME) + @staticmethod def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str: md5 = insecure_hash.md5(dataset_name.encode("utf-8")) @@ -1231,15 +1270,43 @@ def from_yaml(cls, root, file_name): tags=dict_from_yaml["tags"], ) + class _FileStoreOutput(NamedTuple): + source_type: int + source_id: str + destination_type: int + destination_id: str + tags: Dict[str, str] + + def write_yaml(self, root: str, file_name: str): + dict_for_yaml = { + "source_type": OutputVertexType.Name(self.source_type), + "source_id": self.source_id, + "destination_type": OutputVertexType.Name(self.destination_type), + "destination_id": self.source_id, + "tags": self.tags, + } + write_yaml(root, file_name, dict_for_yaml) + + @classmethod + def from_yaml(cls, root, file_name): + dict_from_yaml = FileStore._read_yaml(root, file_name) + return cls( + source_type=OutputVertexType.Value(dict_from_yaml["source_type"]), + source_id=dict_from_yaml["source_id"], + destination_type=OutputVertexType.Value(dict_from_yaml["destination_type"]), + destination_id=dict_from_yaml["destination_id"], + tags=dict_from_yaml["tags"], + ) + def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id) inputs_parent_path = os.path.join(run_dir, 
FileStore.INPUTS_FOLDER_NAME) - experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) if not os.path.exists(inputs_parent_path): return RunInputs(dataset_inputs=[], model_inputs=[]) + experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) dataset_inputs = self._get_dataset_inputs(run_info, inputs_parent_path, experiment_dir) - model_inputs = self._get_model_inputs(run_info, inputs_parent_path, experiment_dir) + model_inputs = self._get_model_inputs(inputs_parent_path, experiment_dir) return RunInputs(dataset_inputs=dataset_inputs, model_inputs=model_inputs) def _get_dataset_inputs( @@ -1298,6 +1365,33 @@ def _get_model_inputs( return model_inputs + def _get_all_outputs(self, run_info: RunInfo) -> RunOutputs: + run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id) + outputs_parent_path = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME) + if not os.path.exists(outputs_parent_path): + return RunOutputs(model_outputs=[]) + + experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) + model_outputs = self._get_model_outputs(outputs_parent_path, experiment_dir) + return RunOutputs(model_outputs=model_outputs) + + def _get_model_outputs( + self, outputs_parent_path: str, experiment_dir: str + ) -> List[ModelOutput]: + model_outputs = [] + for output_dir in os.listdir(outputs_parent_path): + output_dir_full_path = os.path.join(outputs_parent_path, output_dir) + fs_output = FileStore._FileStoreOutput.from_yaml( + output_dir_full_path, FileStore.META_DATA_FILE_NAME + ) + if fs_output.destination_type != OutputVertexType.MODEL_OUTPUT: + continue + + model_output = ModelOutput(model_id=fs_output.destination_id) + model_outputs.append(model_output) + + return model_outputs + def _search_datasets(self, experiment_ids) -> List[_DatasetSummary]: """ Return all dataset summaries associated to the given experiments. 
From 67b6472a6d535f9753c1d71124fdc82d72907d09 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 01:22:04 -0700 Subject: [PATCH 05/62] proggy Signed-off-by: dbczumar --- mlflow/entities/model.py | 11 +++++++++++ mlflow/store/tracking/file_store.py | 10 +++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlflow/entities/model.py b/mlflow/entities/model.py index 28e236fd41f4e..c57aedec5f9aa 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/model.py @@ -16,6 +16,7 @@ def __init__( experiment_id: str, # New field added model_id: str, name: str, + artifact_location: str, # New field added creation_timestamp: int, last_updated_timestamp: int, run_id: Optional[str] = None, @@ -28,6 +29,7 @@ def __init__( self._experiment_id: str = experiment_id # New field initialized self._model_id: str = model_id self._name: str = name + self._artifact_location: str = artifact_location # New field initialized self._creation_time: int = creation_timestamp self._last_updated_timestamp: int = last_updated_timestamp self._run_id: Optional[str] = run_id @@ -63,6 +65,15 @@ def name(self) -> str: def name(self, new_name: str): self._name = new_name + @property + def artifact_location(self) -> str: + """String. Location of the model artifacts.""" + return self._artifact_location + + @artifact_location.setter + def artifact_location(self, new_artifact_location: str): + self._artifact_location = new_artifact_location + @property def creation_timestamp(self) -> int: """Integer. 
Model creation timestamp (milliseconds since the Unix epoch).""" diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 97d8610e559a2..1ea43aaa806b8 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -1875,12 +1875,13 @@ def create_model( ) model_id = str(uuid.uuid4()) - artifact_uri = self._get_model_artifact_dir(experiment_id, model_id) + artifact_location = self._get_model_artifact_dir(experiment_id, model_id) creation_timestamp = int(time.time() * 1000) model = Model( experiment_id=experiment_id, model_id=model_id, name=name, + artifact_location=artifact_location, creation_timestamp=creation_timestamp, last_updated_timestamp=creation_timestamp, run_id=run_id, @@ -1892,7 +1893,7 @@ def create_model( # Persist model metadata and create directories for logging metrics, tags model_dir = self._get_model_dir(experiment_id, model_id) mkdir(model_dir) - model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model, artifact_uri) + model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model) write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict) for tag in tags or []: self.set_model_tag(model_id=model_id, tag=tag) @@ -1920,7 +1921,7 @@ def finalize_model(self, model_id: str, status: ModelStatus) -> Model: model.status = status model.last_updated_timestamp = int(time.time() * 1000) model_dir = self._get_model_dir(model.experiment_id, model.model_id) - model_info_dict = self._make_persisted_model_dict(model, model_dict["artifact_location"]) + model_info_dict = self._make_persisted_model_dict(model) write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict, overwrite=True) return self.get_model(model_id) @@ -1947,9 +1948,8 @@ def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: FileStore.ARTIFACTS_FOLDER_NAME, ) - def _make_persisted_model_dict(self, model: Model, artifact_location) -> Dict[str, Any]: + def 
_make_persisted_model_dict(self, model: Model) -> Dict[str, Any]: model_dict = model.to_dictionary() - model_dict["artifact_location"] = artifact_location model_dict.pop("tags", None) model_dict["params"] = {param.key: param.value for param in model.params or []} return model_dict From 4bab2341379261adfeddfb4db1e785fbf070a7bf Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 01:24:10 -0700 Subject: [PATCH 06/62] fix Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 1ea43aaa806b8..e71822867bcc0 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -1873,6 +1873,8 @@ def create_model( f"Could not create model under non-active experiment with ID {experiment_id}.", databricks_pb2.INVALID_STATE, ) + for param in params or []: + _validate_param(param.key, param.value) model_id = str(uuid.uuid4()) artifact_location = self._get_model_artifact_dir(experiment_id, model_id) From f9ddf42dd1c1c221562c037d7202fe2d5e9781d3 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 19 Aug 2024 21:03:55 -0700 Subject: [PATCH 07/62] progress Signed-off-by: dbczumar --- mlflow/entities/metric.py | 12 +++++++++++- mlflow/entities/model_output.py | 8 +++++++- mlflow/store/tracking/file_store.py | 15 +++++++++++++-- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mlflow/entities/metric.py b/mlflow/entities/metric.py index bea6926b95d1f..636ed1a5aacea 100644 --- a/mlflow/entities/metric.py +++ b/mlflow/entities/metric.py @@ -10,11 +10,12 @@ class Metric(_MlflowObject): Metric object. 
""" - def __init__(self, key, value, timestamp, step): + def __init__(self, key, value, timestamp, step, model_id: str = None): self._key = key self._value = value self._timestamp = timestamp self._step = step + self._model_id = model_id @property def key(self): @@ -36,16 +37,24 @@ def step(self): """Integer metric step (x-coordinate).""" return self._step + @property + def model_id(self): + """ID of the Model associated with the metric.""" + return self._model_id + def to_proto(self): metric = ProtoMetric() metric.key = self.key metric.value = self.value metric.timestamp = self.timestamp metric.step = self.step + # TODO: Add model_id to the proto + metric.model_id = self.model_id return metric @classmethod def from_proto(cls, proto): + # TODO: Add model_id to the proto return cls(proto.key, proto.value, proto.timestamp, proto.step) def __eq__(self, __o): @@ -69,6 +78,7 @@ def to_dictionary(self): "value": self.value, "timestamp": self.timestamp, "step": self.step, + "model_id": self.model_id, } @classmethod diff --git a/mlflow/entities/model_output.py b/mlflow/entities/model_output.py index 81608c54ca771..058a50b316271 100644 --- a/mlflow/entities/model_output.py +++ b/mlflow/entities/model_output.py @@ -4,8 +4,9 @@ class ModelOutput(_MlflowObject): """ModelOutput object associated with a Run.""" - def __init__(self, model_id: str): + def __init__(self, model_id: str, step: int) -> None: self._model_id = model_id + self._step = step def __eq__(self, other: _MlflowObject) -> bool: if type(other) is type(self): @@ -16,3 +17,8 @@ def __eq__(self, other: _MlflowObject) -> bool: def model_id(self) -> str: """Model ID""" return self._model_id + + @property + def step(self) -> str: + """Step at which the model was logged""" + return self._step diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index e71822867bcc0..dbb44e0d191b8 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -955,18 +955,24 
@@ def _search_runs( runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results) return runs, next_page_token - def log_metric(self, run_id, metric): + def log_metric(self, run_id: str, metric: Metric): _validate_run_id(run_id) _validate_metric(metric.key, metric.value, metric.timestamp, metric.step) run_info = self._get_run_info(run_id) check_run_is_active(run_info) self._log_run_metric(run_info, metric) + if metric.model_id is not None: + self._log_model_metric(model_id=metric.model_id, metric=metric) + def _log_run_metric(self, run_info, metric): metric_path = self._get_metric_path(run_info.experiment_id, run_info.run_id, metric.key) make_containing_dirs(metric_path) append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") + def _log_model_metric(self, model_id, metric): + pass + def _writeable_value(self, tag_value): if tag_value is None: return "" @@ -1221,6 +1227,7 @@ def log_outputs(self, run_id, models: Optional[List[ModelOutput]] = None): destination_type=OutputVertexType.MODEL_OUTPUT, destination_id=run_id, tags={}, + step=model_output.step, ) fs_output.write_yaml(output_dir, FileStore.META_DATA_FILE_NAME) @@ -1276,6 +1283,7 @@ class _FileStoreOutput(NamedTuple): destination_type: int destination_id: str tags: Dict[str, str] + step: int def write_yaml(self, root: str, file_name: str): dict_for_yaml = { @@ -1284,6 +1292,7 @@ def write_yaml(self, root: str, file_name: str): "destination_type": OutputVertexType.Name(self.destination_type), "destination_id": self.source_id, "tags": self.tags, + "step": self.step, } write_yaml(root, file_name, dict_for_yaml) @@ -1296,6 +1305,7 @@ def from_yaml(cls, root, file_name): destination_type=OutputVertexType.Value(dict_from_yaml["destination_type"]), destination_id=dict_from_yaml["destination_id"], tags=dict_from_yaml["tags"], + step=dict_from_yaml["step"], ) def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: @@ -1387,7 +1397,7 @@ def _get_model_outputs( if 
fs_output.destination_type != OutputVertexType.MODEL_OUTPUT: continue - model_output = ModelOutput(model_id=fs_output.destination_id) + model_output = ModelOutput(model_id=fs_output.destination_id, step=fs_output.step) model_outputs.append(model_output) return model_outputs @@ -1897,6 +1907,7 @@ def create_model( mkdir(model_dir) model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model) write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict) + mkdir(model_dir, FileStore.METRICS_FOLDER_NAME) for tag in tags or []: self.set_model_tag(model_id=model_id, tag=tag) From d33178fae4f49b273dee68637bbfaa9cbe4ad2ee Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 00:02:50 -0700 Subject: [PATCH 08/62] fix Signed-off-by: dbczumar --- mlflow/entities/metric.py | 37 +++++++++++- mlflow/entities/model.py | 31 ++++++++-- mlflow/store/tracking/file_store.py | 90 +++++++++++++++++++++++++++-- 3 files changed, 147 insertions(+), 11 deletions(-) diff --git a/mlflow/entities/metric.py b/mlflow/entities/metric.py index 636ed1a5aacea..98714789676a8 100644 --- a/mlflow/entities/metric.py +++ b/mlflow/entities/metric.py @@ -1,3 +1,5 @@ +from typing import Optional + from mlflow.entities._mlflow_object import _MlflowObject from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -10,12 +12,29 @@ class Metric(_MlflowObject): Metric object. 
""" - def __init__(self, key, value, timestamp, step, model_id: str = None): + def __init__( + self, + key, + value, + timestamp, + step, + model_id: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, + ): + if (dataset_name, dataset_digest).count(None) == 1: + raise MlflowException( + "Both dataset_name and dataset_digest must be provided if one is provided", + INVALID_PARAMETER_VALUE, + ) + self._key = key self._value = value self._timestamp = timestamp self._step = step self._model_id = model_id + self._dataset_name = dataset_name + self._dataset_digest = dataset_digest @property def key(self): @@ -42,19 +61,29 @@ def model_id(self): """ID of the Model associated with the metric.""" return self._model_id + @property + def dataset_name(self) -> Optional[str]: + """String. Name of the dataset associated with the metric.""" + return self._dataset_name + + @property + def dataset_digest(self) -> Optional[str]: + """String. Digest of the dataset associated with the metric.""" + return self._dataset_digest + def to_proto(self): metric = ProtoMetric() metric.key = self.key metric.value = self.value metric.timestamp = self.timestamp metric.step = self.step - # TODO: Add model_id to the proto + # TODO: Add model_id, dataset_name, and dataset_digest to the proto metric.model_id = self.model_id return metric @classmethod def from_proto(cls, proto): - # TODO: Add model_id to the proto + # TODO: Add model_id, dataset_name, and dataset_digest to the proto return cls(proto.key, proto.value, proto.timestamp, proto.step) def __eq__(self, __o): @@ -79,6 +108,8 @@ def to_dictionary(self): "timestamp": self.timestamp, "step": self.step, "model_id": self.model_id, + "dataset_name": self.dataset_name, + "dataset_digest": self.dataset_digest, } @classmethod diff --git a/mlflow/entities/model.py b/mlflow/entities/model.py index c57aedec5f9aa..0c8ab7b5a39ee 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/model.py @@ -1,6 +1,7 
@@ from typing import Any, Dict, List, Optional from mlflow.entities._mlflow_object import _MlflowObject +from mlflow.entities.metric import Metric from mlflow.entities.model_param import ModelParam from mlflow.entities.model_status import ModelStatus from mlflow.entities.model_tag import ModelTag @@ -13,30 +14,34 @@ class Model(_MlflowObject): def __init__( self, - experiment_id: str, # New field added + experiment_id: str, model_id: str, name: str, - artifact_location: str, # New field added + artifact_location: str, creation_timestamp: int, last_updated_timestamp: int, + model_type: Optional[str] = None, run_id: Optional[str] = None, status: ModelStatus = ModelStatus.READY, status_message: Optional[str] = None, tags: Optional[List[ModelTag]] = None, params: Optional[ModelParam] = None, + metrics: Optional[List[Metric]] = None, ): super().__init__() - self._experiment_id: str = experiment_id # New field initialized + self._experiment_id: str = experiment_id self._model_id: str = model_id self._name: str = name - self._artifact_location: str = artifact_location # New field initialized + self._artifact_location: str = artifact_location self._creation_time: int = creation_timestamp self._last_updated_timestamp: int = last_updated_timestamp + self._model_type: Optional[str] = model_type self._run_id: Optional[str] = run_id self._status: ModelStatus = status self._status_message: Optional[str] = status_message self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} self._params: Optional[ModelParam] = params + self._metrics: Optional[List[Metric]] = metrics @property def experiment_id(self) -> str: @@ -90,6 +95,15 @@ def last_updated_timestamp(self) -> int: def last_updated_timestamp(self, updated_timestamp: int): self._last_updated_timestamp = updated_timestamp + @property + def model_type(self) -> Optional[str]: + """String. 
Type of the model.""" + return self._model_type + + @model_type.setter + def model_type(self, new_model_type: Optional[str]): + self._model_type = new_model_type + @property def run_id(self) -> Optional[str]: """String. MLflow run ID that generated this model.""" @@ -119,6 +133,15 @@ def params(self) -> Optional[ModelParam]: """Model parameters.""" return self._params + @property + def metrics(self) -> Optional[List[Metric]]: + """List of metrics associated with this Model.""" + return self._metrics + + @metrics.setter + def metrics(self, new_metrics: Optional[List[Metric]]): + self._metrics = new_metrics + @classmethod def _properties(cls) -> List[str]: # aggregate with base class properties since cls.__dict__ does not do it automatically diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index dbb44e0d191b8..ccceda6936641 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -246,6 +246,12 @@ def _get_metric_path(self, experiment_id, run_uuid, metric_key): self._get_run_dir(experiment_id, run_uuid), FileStore.METRICS_FOLDER_NAME, metric_key ) + def _get_model_metric_path(self, experiment_id: str, model_id: str, metric_key: str) -> str: + _validate_metric_name(metric_key) + return os.path.join( + self._get_model_dir(experiment_id, model_id), FileStore.METRICS_FOLDER_NAME, metric_key + ) + def _get_param_path(self, experiment_id, run_uuid, param_name): _validate_run_id(run_uuid) _validate_param_name(param_name) @@ -962,16 +968,31 @@ def log_metric(self, run_id: str, metric: Metric): check_run_is_active(run_info) self._log_run_metric(run_info, metric) if metric.model_id is not None: - self._log_model_metric(model_id=metric.model_id, metric=metric) - + self._log_model_metric( + experiment_id=run_info.experiment_id, + model_id=metric.model_id, + run_id=run_id, + metric=metric, + ) def _log_run_metric(self, run_info, metric): metric_path = self._get_metric_path(run_info.experiment_id, 
run_info.run_id, metric.key) make_containing_dirs(metric_path) append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") - def _log_model_metric(self, model_id, metric): - pass + def _log_model_metric(self, experiment_id: str, model_id: str, run_id: str, metric: Metric): + metric_path = self._get_model_metric_path( + experiment_id=experiment_id, model_id=model_id, metric_key=metric.key + ) + make_containing_dirs(metric_path) + if metric.dataset_name is not None and metric.dataset_digest is not None: + append_to( + metric_path, + f"{metric.timestamp} {metric.value} {metric.step} {run_id} {metric.dataset_name} " + f"{metric.dataset_digest}\n", + ) + else: + append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step} {run_id}\n") def _writeable_value(self, tag_value): if tag_value is None: @@ -1093,6 +1114,13 @@ def log_batch(self, run_id, metrics, params, tags): self._log_run_param(run_info, param) for metric in metrics: self._log_run_metric(run_info, metric) + if metric.model_id is not None: + self._log_model_metric( + experiment_id=run_info.experiment_id, + model_id=metric.model_id, + run_id=run_id, + metric=metric, + ) for tag in tags: # NB: If the tag run name value is set, update the run info to assure # synchronization. 
@@ -1979,6 +2007,7 @@ def _get_model_dict(self, model_id: str) -> Dict[str, Any]: f"Model '{model_id}' metadata is in invalid state.", databricks_pb2.INVALID_STATE ) model_dict["tags"] = self._get_all_model_tags(model_dir) + model_dict["metrics"] = self._get_all_model_metrics(model_id=model_id, model_dir=model_dir) return model_dict def _get_model_dir(self, experiment_id: str, model_id: str) -> str: @@ -2010,3 +2039,56 @@ def _get_all_model_tags(self, model_dir: str) -> List[ModelTag]: for tag_file in tag_files: tags.append(self._get_tag_from_file(parent_path, tag_file)) return tags + + def _get_all_model_metrics(self, model_id: str, model_dir: str) -> List[Metric]: + parent_path, metric_files = self._get_resource_files( + model_dir, FileStore.METRICS_FOLDER_NAME + ) + metrics = [] + for metric_file in metric_files: + metrics.append( + FileStore._get_model_metric_from_file( + model_id=model_id, parent_path=parent_path, metric_name=metric_file + ) + ) + return metrics + + @staticmethod + def _get_model_metric_from_file(model_id: str, parent_path: str, metric_name: str) -> Metric: + _validate_metric_name(metric_name) + metric_objs = [ + FileStore._get_model_metric_from_line(model_id, metric_name, line) + for line in read_file_lines(parent_path, metric_name) + ] + if len(metric_objs) == 0: + raise ValueError(f"Metric '{metric_name}' is malformed. No data found.") + # Python performs element-wise comparison of equal-length tuples, ordering them + # based on their first differing element. Therefore, we use max() operator to find the + # largest value at the largest timestamp. 
For more information, see + # https://docs.python.org/3/reference/expressions.html#value-comparisons + return max(metric_objs, key=lambda m: (m.step, m.timestamp, m.value)) + + @staticmethod + def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: str) -> Metric: + metric_parts = metric_line.strip().split(" ") + if len(metric_parts) not in [4, 6]: + raise MlflowException( + f"Metric '{metric_name}' is malformed; persisted metric data contained " + f"{len(metric_parts)} fields. Expected 4 or 6 fields.", + databricks_pb2.INTERNAL_ERROR, + ) + ts = int(metric_parts[0]) + val = float(metric_parts[1]) + step = int(metric_parts[2]) + dataset_name = str(metric_parts[4]) if len(metric_parts) == 6 else None + dataset_digest = str(metric_parts[5]) if len(metric_parts) == 6 else None + # TODO: Read run ID from the metric file and pass it to the Metric constructor + return Metric( + key=metric_name, + value=val, + timestamp=ts, + step=step, + model_id=model_id, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) From c51ae7f13f86968bba68e86a54e12d6012925acb Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 00:06:38 -0700 Subject: [PATCH 09/62] Fix Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index ccceda6936641..491d2846ad38b 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -798,17 +798,26 @@ def _get_all_metrics(self, run_info): @staticmethod def _get_metric_from_line(metric_name, metric_line, exp_id): metric_parts = metric_line.strip().split(" ") - if len(metric_parts) != 2 and len(metric_parts) != 3: + if len(metric_parts) != 2 and len(metric_parts) != 3 and len(metric_parts) != 5: raise MlflowException( f"Metric '{metric_name}' is malformed; persisted metric data contained " - f"{len(metric_parts)} 
fields. Expected 2 or 3 fields. " + f"{len(metric_parts)} fields. Expected 2, 3, or 5 fields. " f"Experiment id: {exp_id}", databricks_pb2.INTERNAL_ERROR, ) ts = int(metric_parts[0]) val = float(metric_parts[1]) step = int(metric_parts[2]) if len(metric_parts) == 3 else 0 - return Metric(key=metric_name, value=val, timestamp=ts, step=step) + dataset_name = str(metric_parts[3]) if len(metric_parts) == 5 else None + dataset_digest = str(metric_parts[4]) if len(metric_parts) == 5 else None + return Metric( + key=metric_name, + value=val, + timestamp=ts, + step=step, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None): """ @@ -978,7 +987,14 @@ def log_metric(self, run_id: str, metric: Metric): def _log_run_metric(self, run_info, metric): metric_path = self._get_metric_path(run_info.experiment_id, run_info.run_id, metric.key) make_containing_dirs(metric_path) - append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") + if metric.dataset_name is not None and metric.dataset_digest is not None: + append_to( + metric_path, + f"{metric.timestamp} {metric.value} {metric.step} {metric.dataset_name} " + f"{metric.dataset_digest}\n", + ) + else: + append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") def _log_model_metric(self, experiment_id: str, model_id: str, run_id: str, metric: Metric): metric_path = self._get_model_metric_path( From e1a27f1eeecba49b0c92b6129ca7500ea576a8e6 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 00:15:06 -0700 Subject: [PATCH 10/62] progress Signed-off-by: dbczumar --- mlflow/entities/metric.py | 13 ++++++++++--- mlflow/store/tracking/file_store.py | 19 ++++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/mlflow/entities/metric.py b/mlflow/entities/metric.py index 98714789676a8..68bd68451a150 100644 --- a/mlflow/entities/metric.py +++ b/mlflow/entities/metric.py @@ 
-21,6 +21,7 @@ def __init__( model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_digest: Optional[str] = None, + run_id: Optional[str] = None, ): if (dataset_name, dataset_digest).count(None) == 1: raise MlflowException( @@ -35,6 +36,7 @@ def __init__( self._model_id = model_id self._dataset_name = dataset_name self._dataset_digest = dataset_digest + self._run_id = run_id @property def key(self): @@ -71,19 +73,23 @@ def dataset_digest(self) -> Optional[str]: """String. Digest of the dataset associated with the metric.""" return self._dataset_digest + @property + def run_id(self) -> Optional[str]: + """String. Run ID associated with the metric.""" + return self._run_id + def to_proto(self): metric = ProtoMetric() metric.key = self.key metric.value = self.value metric.timestamp = self.timestamp metric.step = self.step - # TODO: Add model_id, dataset_name, and dataset_digest to the proto - metric.model_id = self.model_id + # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto return metric @classmethod def from_proto(cls, proto): - # TODO: Add model_id, dataset_name, and dataset_digest to the proto + # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto return cls(proto.key, proto.value, proto.timestamp, proto.step) def __eq__(self, __o): @@ -110,6 +116,7 @@ def to_dictionary(self): "model_id": self.model_id, "dataset_name": self.dataset_name, "dataset_digest": self.dataset_digest, + "run_id": self._run_id, } @classmethod diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 491d2846ad38b..8011eb0def22f 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -767,10 +767,12 @@ def _get_resource_files(self, root_dir, subfolder_name): return source_dirs[0], file_names @staticmethod - def _get_metric_from_file(parent_path, metric_name, exp_id): + def _get_metric_from_file( + parent_path: str, metric_name: str, run_id: str, 
exp_id: str + ) -> Metric: _validate_metric_name(metric_name) metric_objs = [ - FileStore._get_metric_from_line(metric_name, line, exp_id) + FileStore._get_metric_from_line(run_id, metric_name, line, exp_id) for line in read_file_lines(parent_path, metric_name) ] if len(metric_objs) == 0: @@ -791,12 +793,16 @@ def _get_all_metrics(self, run_info): metrics = [] for metric_file in metric_files: metrics.append( - self._get_metric_from_file(parent_path, metric_file, run_info.experiment_id) + self._get_metric_from_file( + parent_path, metric_file, run_info.run_id, run_info.experiment_id + ) ) return metrics @staticmethod - def _get_metric_from_line(metric_name, metric_line, exp_id): + def _get_metric_from_line( + run_id: str, metric_name: str, metric_line: str, exp_id: str + ) -> Metric: metric_parts = metric_line.strip().split(" ") if len(metric_parts) != 2 and len(metric_parts) != 3 and len(metric_parts) != 5: raise MlflowException( @@ -817,6 +823,7 @@ def _get_metric_from_line(metric_name, metric_line, exp_id): step=step, dataset_name=dataset_name, dataset_digest=dataset_digest, + run_id=run_id, ) def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None): @@ -856,7 +863,7 @@ def get_metric_history(self, run_id, metric_key, max_results=None, page_token=No return PagedList([], None) return PagedList( [ - FileStore._get_metric_from_line(metric_key, line, run_info.experiment_id) + FileStore._get_metric_from_line(run_id, metric_key, line, run_info.experiment_id) for line in read_file_lines(parent_path, metric_key) ], None, @@ -2096,6 +2103,7 @@ def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: st ts = int(metric_parts[0]) val = float(metric_parts[1]) step = int(metric_parts[2]) + run_id = str(metric_parts[3]) dataset_name = str(metric_parts[4]) if len(metric_parts) == 6 else None dataset_digest = str(metric_parts[5]) if len(metric_parts) == 6 else None # TODO: Read run ID from the metric file and pass it to the Metric 
constructor @@ -2107,4 +2115,5 @@ def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: st model_id=model_id, dataset_name=dataset_name, dataset_digest=dataset_digest, + run_id=run_id, ) From 90bd5b77227dbbf32734b699441a64c3a4039352 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 11:52:21 -0700 Subject: [PATCH 11/62] progress Signed-off-by: dbczumar --- .../entities/model_registry/model_version.py | 18 +++++++++++++++ mlflow/tracking/_tracking_service/client.py | 23 +++++++++++++++++++ mlflow/tracking/client.py | 11 +++++++++ 3 files changed, 52 insertions(+) diff --git a/mlflow/entities/model_registry/model_version.py b/mlflow/entities/model_registry/model_version.py index a30d022c9eb0f..05267f0948d75 100644 --- a/mlflow/entities/model_registry/model_version.py +++ b/mlflow/entities/model_registry/model_version.py @@ -1,3 +1,7 @@ +from typing import List, Optional + +from mlflow.entities.metric import Metric +from mlflow.entities.model_param import ModelParam from mlflow.entities.model_registry._model_registry_entity import _ModelRegistryEntity from mlflow.entities.model_registry.model_version_status import ModelVersionStatus from mlflow.entities.model_registry.model_version_tag import ModelVersionTag @@ -26,6 +30,8 @@ def __init__( tags=None, run_link=None, aliases=None, + params: Optional[List[ModelParam]] = None, + metrics: Optional[List[Metric]] = None, ): super().__init__() self._name = name @@ -42,6 +48,8 @@ def __init__( self._status_message = status_message self._tags = {tag.key: tag.value for tag in (tags or [])} self._aliases = aliases or [] + self._params = params + self._metrics = metrics @property def name(self): @@ -135,6 +143,16 @@ def aliases(self): def aliases(self, aliases): self._aliases = aliases + @property + def params(self) -> Optional[List[ModelParam]]: + """List of parameters associated with this model version.""" + return self._params + + @property + def metrics(self) -> Optional[List[Metric]]: + 
"""List of metrics associated with this model version.""" + return self._metrics + @classmethod def _properties(cls): # aggregate with base class properties since cls.__dict__ does not do it automatically diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 70f5dd38681c0..bfe7e3e2a2b63 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -14,6 +14,9 @@ from mlflow.entities import ( ExperimentTag, Metric, + Model, + ModelParam, + ModelTag, Param, RunStatus, RunTag, @@ -976,3 +979,23 @@ def search_runs( order_by=order_by, page_token=page_token, ) + + def create_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + ) -> Model: + return self.store.create_model( + experiment_id=experiment_id, + name=name, + run_id=run_id, + tags=[ModelTag(key, value) for key, value in tags.items()] + if tags is not None + else tags, + params=[ModelParam(key, value) for key, value in params.items()] + if params is not None + else params, + ) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index f85e93755e296..65ebc47f6bd98 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -24,6 +24,7 @@ Experiment, FileInfo, Metric, + Model, Param, Run, RunTag, @@ -4735,3 +4736,13 @@ def print_model_version_info(mv): """ _validate_model_name(name) return self._get_registry_client().get_model_version_by_alias(name, alias) + + def create_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + ) -> Model: + return self._tracking_client.create_model(experiment_id, name, run_id, tags, params) From 8e429955ff6555ad1b9a878b6b55d30139c3bccf Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 12:00:38 -0700 Subject: [PATCH 
12/62] fix Signed-off-by: dbczumar --- .../entities/model_registry/model_version.py | 114 +++++++++--------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/mlflow/entities/model_registry/model_version.py b/mlflow/entities/model_registry/model_version.py index 05267f0948d75..9651abefaaa76 100644 --- a/mlflow/entities/model_registry/model_version.py +++ b/mlflow/entities/model_registry/model_version.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from mlflow.entities.metric import Metric from mlflow.entities.model_param import ModelParam @@ -16,131 +16,131 @@ class ModelVersion(_ModelRegistryEntity): def __init__( self, - name, - version, - creation_timestamp, - last_updated_timestamp=None, - description=None, - user_id=None, - current_stage=None, - source=None, - run_id=None, - status=ModelVersionStatus.to_string(ModelVersionStatus.READY), - status_message=None, - tags=None, - run_link=None, - aliases=None, + name: str, + version: str, + creation_timestamp: int, + last_updated_timestamp: Optional[int] = None, + description: Optional[str] = None, + user_id: Optional[str] = None, + current_stage: Optional[str] = None, + source: Optional[str] = None, + run_id: Optional[str] = None, + status: str = ModelVersionStatus.to_string(ModelVersionStatus.READY), + status_message: Optional[str] = None, + tags: Optional[List[ModelVersionTag]] = None, + run_link: Optional[str] = None, + aliases: Optional[List[str]] = None, params: Optional[List[ModelParam]] = None, metrics: Optional[List[Metric]] = None, ): super().__init__() - self._name = name - self._version = version - self._creation_time = creation_timestamp - self._last_updated_timestamp = last_updated_timestamp - self._description = description - self._user_id = user_id - self._current_stage = current_stage - self._source = source - self._run_id = run_id - self._run_link = run_link - self._status = status - self._status_message = status_message - self._tags = 
{tag.key: tag.value for tag in (tags or [])} - self._aliases = aliases or [] - self._params = params - self._metrics = metrics - - @property - def name(self): + self._name: str = name + self._version: str = version + self._creation_time: int = creation_timestamp + self._last_updated_timestamp: Optional[int] = last_updated_timestamp + self._description: Optional[str] = description + self._user_id: Optional[str] = user_id + self._current_stage: Optional[str] = current_stage + self._source: Optional[str] = source + self._run_id: Optional[str] = run_id + self._run_link: Optional[str] = run_link + self._status: str = status + self._status_message: Optional[str] = status_message + self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} + self._aliases: List[str] = aliases or [] + self._params: Optional[List[ModelParam]] = params + self._metrics: Optional[List[Metric]] = metrics + + @property + def name(self) -> str: """String. Unique name within Model Registry.""" return self._name @name.setter - def name(self, new_name): + def name(self, new_name: str): self._name = new_name @property - def version(self): - """version""" + def version(self) -> str: + """Version""" return self._version @property - def creation_timestamp(self): + def creation_timestamp(self) -> int: """Integer. Model version creation timestamp (milliseconds since the Unix epoch).""" return self._creation_time @property - def last_updated_timestamp(self): + def last_updated_timestamp(self) -> Optional[int]: """Integer. Timestamp of last update for this model version (milliseconds since the Unix epoch). """ return self._last_updated_timestamp @last_updated_timestamp.setter - def last_updated_timestamp(self, updated_timestamp): + def last_updated_timestamp(self, updated_timestamp: int): self._last_updated_timestamp = updated_timestamp @property - def description(self): + def description(self) -> Optional[str]: """String. 
Description""" return self._description @description.setter - def description(self, description): + def description(self, description: str): self._description = description @property - def user_id(self): + def user_id(self) -> Optional[str]: """String. User ID that created this model version.""" return self._user_id @property - def current_stage(self): + def current_stage(self) -> Optional[str]: """String. Current stage of this model version.""" return self._current_stage @current_stage.setter - def current_stage(self, stage): + def current_stage(self, stage: str): self._current_stage = stage @property - def source(self): + def source(self) -> Optional[str]: """String. Source path for the model.""" return self._source @property - def run_id(self): + def run_id(self) -> Optional[str]: """String. MLflow run ID that generated this model.""" return self._run_id @property - def run_link(self): + def run_link(self) -> Optional[str]: """String. MLflow run link referring to the exact run that generated this model version.""" return self._run_link @property - def status(self): + def status(self) -> str: """String. Current Model Registry status for this model.""" return self._status @property - def status_message(self): + def status_message(self) -> Optional[str]: """String. 
Descriptive message for error status conditions.""" return self._status_message @property - def tags(self): + def tags(self) -> Dict[str, str]: """Dictionary of tag key (string) -> tag value for the current model version.""" return self._tags @property - def aliases(self): + def aliases(self) -> List[str]: """List of aliases (string) for the current model version.""" return self._aliases @aliases.setter - def aliases(self, aliases): + def aliases(self, aliases: List[str]): self._aliases = aliases @property @@ -154,16 +154,16 @@ def metrics(self) -> Optional[List[Metric]]: return self._metrics @classmethod - def _properties(cls): + def _properties(cls) -> List[str]: # aggregate with base class properties since cls.__dict__ does not do it automatically return sorted(cls._get_properties_helper()) - def _add_tag(self, tag): + def _add_tag(self, tag: ModelVersionTag): self._tags[tag.key] = tag.value # proto mappers @classmethod - def from_proto(cls, proto): + def from_proto(cls, proto: ProtoModelVersion) -> "ModelVersion": # input: mlflow.protos.model_registry_pb2.ModelVersion # returns: ModelVersion entity model_version = cls( @@ -185,7 +185,7 @@ def from_proto(cls, proto): model_version._add_tag(ModelVersionTag.from_proto(tag)) return model_version - def to_proto(self): + def to_proto(self) -> ProtoModelVersion: # input: ModelVersion entity # returns mlflow.protos.model_registry_pb2.ModelVersion model_version = ProtoModelVersion() From 3743f40e820524952bf7bb316dd72aa5a25e2adb Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 12:04:48 -0700 Subject: [PATCH 13/62] progress Signed-off-by: dbczumar --- mlflow/entities/model_registry/model_version.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlflow/entities/model_registry/model_version.py b/mlflow/entities/model_registry/model_version.py index 9651abefaaa76..e47f9c8deb1de 100644 --- a/mlflow/entities/model_registry/model_version.py +++ b/mlflow/entities/model_registry/model_version.py @@ 
-30,6 +30,9 @@ def __init__( tags: Optional[List[ModelVersionTag]] = None, run_link: Optional[str] = None, aliases: Optional[List[str]] = None, + # TODO: Make model_id a required field + # (currently optional to minimize breakages during prototype development) + model_id: Optional[str] = None, params: Optional[List[ModelParam]] = None, metrics: Optional[List[Metric]] = None, ): @@ -48,6 +51,7 @@ def __init__( self._status_message: Optional[str] = status_message self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} self._aliases: List[str] = aliases or [] + self._model_id: Optional[str] = model_id self._params: Optional[List[ModelParam]] = params self._metrics: Optional[List[Metric]] = metrics @@ -143,6 +147,11 @@ def aliases(self) -> List[str]: def aliases(self, aliases: List[str]): self._aliases = aliases + @property + def model_id(self) -> Optional[str]: + """String. ID of the model associated with this version.""" + return self._model_id + @property def params(self) -> Optional[List[ModelParam]]: """List of parameters associated with this model version.""" @@ -183,6 +192,7 @@ def from_proto(cls, proto: ProtoModelVersion) -> "ModelVersion": ) for tag in proto.tags: model_version._add_tag(ModelVersionTag.from_proto(tag)) + # TODO: Include params, metrics, and model ID in proto return model_version def to_proto(self) -> ProtoModelVersion: @@ -214,4 +224,5 @@ def to_proto(self) -> ProtoModelVersion: [ProtoModelVersionTag(key=key, value=value) for key, value in self._tags.items()] ) model_version.aliases.extend(self.aliases) + # TODO: Include params, metrics, and model ID in proto return model_version From 1ee6da8ee81578fdb69ddce0426c1e5c6284b098 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 14:05:29 -0700 Subject: [PATCH 14/62] fix Signed-off-by: dbczumar --- mlflow/store/model_registry/file_store.py | 20 +++++++++++++++++++- mlflow/tracking/_tracking_service/client.py | 3 +++ mlflow/tracking/client.py | 3 +++ 3 files changed, 25 
insertions(+), 1 deletion(-) diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index da5869fe5f825..8f2fc81051696 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -5,7 +5,7 @@ import time import urllib from os.path import join -from typing import List +from typing import List, Optional from mlflow.entities.model_registry import ( ModelVersion, @@ -570,9 +570,23 @@ def _get_model_version_aliases(self, directory): return [alias.alias for alias in aliases if alias.version == version] def _get_file_model_version_from_dir(self, directory) -> FileModelVersion: + from mlflow.tracking.client import MlflowClient + meta = FileStore._read_yaml(directory, FileStore.META_DATA_FILE_NAME) meta["tags"] = self._get_model_version_tags_from_dir(directory) meta["aliases"] = self._get_model_version_aliases(directory) + # Fetch metrics and params from model ID + # + # TODO: Propagate tracking URI to file store directly, rather than relying on global + # URI (individual MlflowClient instances may have different tracking URIs) + if "model_id" in meta: + try: + model = MlflowClient().get_model(meta["model_id"]) + meta["metrics"] = model.metrics + meta["params"] = model.params + except Exception: + # TODO: Make this exception handling more specific + pass return FileModelVersion.from_dictionary(meta) def _save_model_version_as_meta_file( @@ -605,6 +619,7 @@ def create_model_version( run_link=None, description=None, local_model_path=None, + model_id: Optional[str] = None, ) -> ModelVersion: """ Create a new model version from given source and run ID. @@ -617,6 +632,8 @@ def create_model_version( instances associated with this model version. run_link: Link to the run from an MLflow tracking server that generated this model. description: Description of the version. + model_id: The ID of the model (from an Experiment) that is being promoted to a model + version, if applicable. 
Returns: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion` @@ -667,6 +684,7 @@ def next_version(registered_model_name): tags=tags, aliases=[], storage_location=storage_location, + model_id=model_id, ) model_version_dir = self._get_model_version_dir(name, version) mkdir(model_version_dir) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index bfe7e3e2a2b63..23eddce08f14f 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -999,3 +999,6 @@ def create_model( if params is not None else params, ) + + def get_model(self, model_id: str) -> Model: + return self.store.get_model(model_id) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 65ebc47f6bd98..3b60332bf947e 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -4746,3 +4746,6 @@ def create_model( params: Optional[Dict[str, str]] = None, ) -> Model: return self._tracking_client.create_model(experiment_id, name, run_id, tags, params) + + def get_model(self, model_id: str) -> Model: + return self._tracking_client.get_model(model_id) From 0cc71bb9c0b492febd525a6d25d815bb06125fb7 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 17:27:08 -0700 Subject: [PATCH 15/62] fix Signed-off-by: dbczumar --- mlflow/entities/model.py | 18 ++++++++++++------ mlflow/store/tracking/file_store.py | 1 - 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlflow/entities/model.py b/mlflow/entities/model.py index 0c8ab7b5a39ee..7e0fc16f1c0b1 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from mlflow.entities._mlflow_object import _MlflowObject from mlflow.entities.metric import Metric @@ -24,8 +24,8 @@ def __init__( run_id: Optional[str] = None, status: ModelStatus = ModelStatus.READY, status_message: 
Optional[str] = None, - tags: Optional[List[ModelTag]] = None, - params: Optional[ModelParam] = None, + tags: Optional[Union[List[ModelTag], Dict[str, str]]] = None, + params: Optional[Union[List[ModelParam], Dict[str, str]]] = None, metrics: Optional[List[Metric]] = None, ): super().__init__() @@ -39,8 +39,14 @@ def __init__( self._run_id: Optional[str] = run_id self._status: ModelStatus = status self._status_message: Optional[str] = status_message - self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} - self._params: Optional[ModelParam] = params + self._tags: Dict[str, str] = ( + {tag.key: tag.value for tag in (tags or [])} if isinstance(tags, list) else tags + ) + self._params: Dict[str, str] = ( + {param.key: param.value for param in (params or [])} + if isinstance(params, list) + else params + ) self._metrics: Optional[List[Metric]] = metrics @property @@ -129,7 +135,7 @@ def tags(self) -> Dict[str, str]: return self._tags @property - def params(self) -> Optional[ModelParam]: + def params(self) -> Dict[str, str]: """Model parameters.""" return self._params diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 8011eb0def22f..5a233a6ae2f46 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2015,7 +2015,6 @@ def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: def _make_persisted_model_dict(self, model: Model) -> Dict[str, Any]: model_dict = model.to_dictionary() model_dict.pop("tags", None) - model_dict["params"] = {param.key: param.value for param in model.params or []} return model_dict def _get_model_dict(self, model_id: str) -> Dict[str, Any]: From bff8a2cc80582877229576edb52f354ccf8fd298 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 17:29:12 -0700 Subject: [PATCH 16/62] finalize Signed-off-by: dbczumar --- mlflow/tracking/_tracking_service/client.py | 4 ++++ mlflow/tracking/client.py | 4 ++++ 2 files changed, 8 
insertions(+) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 23eddce08f14f..f5643cc036d06 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -16,6 +16,7 @@ Metric, Model, ModelParam, + ModelStatus, ModelTag, Param, RunStatus, @@ -1000,5 +1001,8 @@ def create_model( else params, ) + def finalize_model(self, model_id: str, status: ModelStatus) -> Model: + return self.store.finalize_model(model_id, status) + def get_model(self, model_id: str) -> Model: return self.store.get_model(model_id) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 3b60332bf947e..65bc7a04e9daf 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -25,6 +25,7 @@ FileInfo, Metric, Model, + ModelStatus, Param, Run, RunTag, @@ -4747,5 +4748,8 @@ def create_model( ) -> Model: return self._tracking_client.create_model(experiment_id, name, run_id, tags, params) + def finalize_model(self, model_id: str, status: ModelStatus) -> Model: + return self._tracking_client.finalize_model(model_id, status) + def get_model(self, model_id: str) -> Model: return self._tracking_client.get_model(model_id) From 1a94eaf3376f723ee4b465ea03c5f9c0f495118b Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 17:32:21 -0700 Subject: [PATCH 17/62] Tag setting Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 1 - mlflow/tracking/_tracking_service/client.py | 3 +++ mlflow/tracking/client.py | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 5a233a6ae2f46..c960ed458008f 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2000,7 +2000,6 @@ def set_model_tag(self, model_id: str, tag: ModelTag): make_containing_dirs(tag_path) # Don't add trailing newline write_to(tag_path, self._writeable_value(tag.value)) 
- return def get_model(self, model_id: str) -> Model: return Model.from_dictionary(self._get_model_dict(model_id)) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index f5643cc036d06..7d01b54fdfec2 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -1006,3 +1006,6 @@ def finalize_model(self, model_id: str, status: ModelStatus) -> Model: def get_model(self, model_id: str) -> Model: return self.store.get_model(model_id) + + def set_model_tag(self, model_id: str, key: str, value: str): + return self.store.set_model_tag(model_id, ModelTag(key, value)) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 65bc7a04e9daf..5e657d8c054d4 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -4753,3 +4753,6 @@ def finalize_model(self, model_id: str, status: ModelStatus) -> Model: def get_model(self, model_id: str) -> Model: return self._tracking_client.get_model(model_id) + + def set_model_tag(self, model_id: str, key: str, value: str): + return self._tracking_client.set_model_tag(model_id, key, value) From 67b896c4fefe798cdb668b3354293b61af3ee72c Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:08:12 -0700 Subject: [PATCH 18/62] progress Signed-off-by: dbczumar --- mlflow/models/model.py | 135 +++++++++++--------- mlflow/store/tracking/file_store.py | 1 + mlflow/tracking/_tracking_service/client.py | 22 ++++ mlflow/tracking/client.py | 3 + 4 files changed, 104 insertions(+), 57 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 170e89f5f2a57..3163ceb33fa0d 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -14,6 +14,7 @@ import mlflow from mlflow.artifacts import download_artifacts +from mlflow.entities import ModelStatus from mlflow.exceptions import MlflowException from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder from 
mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST @@ -689,13 +690,14 @@ def log( A :py:class:`ModelInfo ` instance that contains the metadata of the logged model. """ - from mlflow.utils.model_utils import _validate_and_get_model_config_from_file registered_model = None with TempDir() as tmp: local_path = tmp.path("model") - if run_id is None: - run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id + + # NO LONGER START A RUN! + # if run_id is None: + # run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id mlflow_model = cls( artifact_path=artifact_path, run_id=run_id, metadata=metadata, resources=resources ) @@ -717,66 +719,71 @@ def log( _logger.warning(_LOG_MODEL_MISSING_INPUT_EXAMPLE_WARNING) elif tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks": _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) - mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id) - - # if the model_config kwarg is passed in, then log the model config as an params - if model_config := kwargs.get("model_config"): - if isinstance(model_config, str): - try: - file_extension = os.path.splitext(model_config)[1].lower() - if file_extension == ".json": - with open(model_config) as f: - model_config = json.load(f) - elif file_extension in [".yaml", ".yml"]: - model_config = _validate_and_get_model_config_from_file(model_config) - else: - _logger.warning( - "Unsupported file format for model config: %s. " - "Failed to load model config.", - model_config, - ) - except Exception as e: - _logger.warning("Failed to load model config from %s: %s", model_config, e) - - try: - from mlflow.models.utils import _flatten_nested_params - # We are using the `/` separator to flatten the nested params - # since we are using the same separator to log nested metrics. 
- params_to_log = _flatten_nested_params(model_config, sep="/") - except Exception as e: - _logger.warning("Failed to flatten nested params: %s", str(e)) - params_to_log = model_config - - try: - mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id) - except Exception as e: - _logger.warning("Failed to log model config as params: %s", str(e)) - - try: - mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id) - except MlflowException: - # We need to swallow all mlflow exceptions to maintain backwards compatibility with - # older tracking servers. Only print out a warning for now. - _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri()) - _logger.debug("", exc_info=True) - - if registered_model_name is not None: - registered_model = mlflow.tracking._model_registry.fluent._register_model( - f"runs:/{run_id}/{mlflow_model.artifact_path}", - registered_model_name, - await_registration_for=await_registration_for, - local_model_path=local_path, - ) - model_info = mlflow_model.get_model_info() - if registered_model is not None: - model_info.registered_model_version = registered_model.version + # NO LONGER LOG ARTIFACTS. CREATE A MODEL AND FINALIZE IT INSTEAD! 
+ client = mlflow.MlflowClient(tracking_uri) + active_run = mlflow.tracking.fluent.active_run() + model = client.create_model( + experiment_id=mlflow.tracking.fluent._get_experiment_id(), + # TODO: Update model name + name=artifact_path, + run_id=active_run.info.run_id if active_run is not None else None, + ) + client.log_model_artifacts(model.model_id, local_path) + client.finalize_model(model.model_id, status=ModelStatus.READY) + + # mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id) + + # # if the model_config kwarg is passed in, then log the model config as an params + # if model_config := kwargs.get("model_config"): + # if isinstance(model_config, str): + # try: + # file_extension = os.path.splitext(model_config)[1].lower() + # if file_extension == ".json": + # with open(model_config) as f: + # model_config = json.load(f) + # elif file_extension in [".yaml", ".yml"]: + # model_config = _validate_and_get_model_config_from_file(model_config) + # else: + # _logger.warning( + # "Unsupported file format for model config: %s. " + # "Failed to load model config.", + # model_config, + # ) + # except Exception as e: + # _logger.warning( + # "Failed to load model config from %s: %s", model_config, e + # ) + # + # try: + # from mlflow.models.utils import _flatten_nested_params + # + # # We are using the `/` separator to flatten the nested params + # # since we are using the same separator to log nested metrics. 
+ # params_to_log = _flatten_nested_params(model_config, sep="/") + # except Exception as e: + # _logger.warning("Failed to flatten nested params: %s", str(e)) + # params_to_log = model_config + # + # try: + # mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id) + # except Exception as e: + # _logger.warning("Failed to log model config as params: %s", str(e)) + # + # try: + # mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id) + # except MlflowException: + # # We need to swallow all mlflow exceptions to maintain backwards compatibility + # # with older tracking servers. Only print out a warning for now. + # _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri()) + # _logger.debug("", exc_info=True) # validate input example works for serving when logging the model if serving_input: from mlflow.models import validate_serving_input try: + model_info = mlflow_model.get_model_info() validate_serving_input(model_info.model_uri, serving_input) except Exception as e: _logger.warning( @@ -792,7 +799,21 @@ def log( exc_info=_logger.isEnabledFor(logging.DEBUG), ) - return model_info + if registered_model_name is not None: + registered_model = mlflow.tracking._model_registry.fluent._register_model( + f"runs:/{run_id}/{mlflow_model.artifact_path}", + registered_model_name, + await_registration_for=await_registration_for, + local_model_path=local_path, + ) + return client.get_model_version(registered_model_name, registered_model.version) + else: + return client.get_model(model.model_id) + # model_info = mlflow_model.get_model_info() + # if registered_model is not None: + # model_info.registered_model_version = registered_model.version + + # return model_info def _copy_model_metadata_for_uc_sharing(local_path, flavor): diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index c960ed458008f..42971683c294f 100644 --- a/mlflow/store/tracking/file_store.py +++ 
b/mlflow/store/tracking/file_store.py @@ -2007,6 +2007,7 @@ def get_model(self, model_id: str) -> Model: def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: return append_to_uri_path( self.get_experiment(experiment_id).artifact_location, + FileStore.MODELS_FOLDER_NAME, model_id, FileStore.ARTIFACTS_FOLDER_NAME, ) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 7d01b54fdfec2..d28f54e7db2be 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -36,6 +36,7 @@ MlflowTraceDataNotFound, ) from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE, ErrorCode +from mlflow.store.artifact.artifact_repo import ArtifactRepository from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.store.entities.paged_list import PagedList from mlflow.store.tracking import ( @@ -1009,3 +1010,24 @@ def get_model(self, model_id: str) -> Model: def set_model_tag(self, model_id: str, key: str, value: str): return self.store.set_model_tag(model_id, ModelTag(key, value)) + + def log_model_artifacts(self, model_id: str, local_dir: str) -> None: + self._get_artifact_repo_for_model(model_id).log_artifacts(local_dir) + + def _get_artifact_repo_for_model(self, model_id: str) -> ArtifactRepository: + # Attempt to fetch the artifact repo from a local cache + cached_repo = utils._artifact_repos_cache.get(model_id) + if cached_repo is not None: + return cached_repo + else: + model = self.get_model(model_id) + artifact_uri = add_databricks_profile_info_to_artifact_uri( + model.artifact_location, self.tracking_uri + ) + artifact_repo = get_artifact_repository(artifact_uri) + # Cache the artifact repo to avoid a future network call, removing the oldest + # entry in the cache if there are too many elements + if len(utils._artifact_repos_cache) > 1024: + utils._artifact_repos_cache.popitem(last=False) + 
utils._artifact_repos_cache[model_id] = artifact_repo + return artifact_repo diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 5e657d8c054d4..a64d5419d0202 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -4756,3 +4756,6 @@ def get_model(self, model_id: str) -> Model: def set_model_tag(self, model_id: str, key: str, value: str): return self._tracking_client.set_model_tag(model_id, key, value) + + def log_model_artifacts(self, model_id: str, local_dir: str) -> None: + self._tracking_client.log_model_artifacts(model_id, local_dir) From 8b0e9e44d356ed1cc7754b8a318db55aa12f41b4 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:24:08 -0700 Subject: [PATCH 19/62] fix Signed-off-by: dbczumar --- mlflow/store/artifact/models_artifact_repo.py | 11 +++++++++-- mlflow/store/artifact/utils/models.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlflow/store/artifact/models_artifact_repo.py b/mlflow/store/artifact/models_artifact_repo.py index f54a062f22716..c21878064b78a 100644 --- a/mlflow/store/artifact/models_artifact_repo.py +++ b/mlflow/store/artifact/models_artifact_repo.py @@ -90,8 +90,15 @@ def _get_model_uri_infos(uri): get_databricks_profile_uri_from_artifact_uri(uri) or mlflow.get_registry_uri() ) client = MlflowClient(registry_uri=databricks_profile_uri) - name, version = get_model_name_and_version(client, uri) - download_uri = client.get_model_version_download_uri(name, version) + name_and_version_or_id = get_model_name_and_version(client, uri) + if len(name_and_version_or_id) == 1: + name = None + version = None + model_id = name_and_version_or_id[0] + download_uri = client.get_model(model_id).artifact_location + else: + name, version = name_and_version_or_id + download_uri = client.get_model_version_download_uri(name, version) return ( name, diff --git a/mlflow/store/artifact/utils/models.py b/mlflow/store/artifact/utils/models.py index 
e4347328d9e52..9e3cab029fc6b 100644 --- a/mlflow/store/artifact/utils/models.py +++ b/mlflow/store/artifact/utils/models.py @@ -37,7 +37,8 @@ def _get_latest_model_version(client, name, stage): class ParsedModelUri(NamedTuple): - name: str + model_id: Optional[str] = None + name: Optional[str] = None version: Optional[str] = None stage: Optional[str] = None alias: Optional[str] = None @@ -47,6 +48,7 @@ def _parse_model_uri(uri): """ Returns a ParsedModelUri tuple. Since a models:/ URI can only have one of {version, stage, 'latest', alias}, it will return + - (id, None, None, None) to look for a specific model by ID, - (name, version, None, None) to look for a specific version, - (name, None, stage, None) to look for the latest version of a stage, - (name, None, None, None) to look for the latest of all versions. @@ -77,16 +79,21 @@ def _parse_model_uri(uri): else: # The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production" return ParsedModelUri(name, stage=suffix) - else: + elif "@" in path: # The URI is an alias URI, e.g. 
"models:/AdsModel1@Champion" alias_parts = parts[0].rsplit("@", 1) if len(alias_parts) != 2 or alias_parts[1].strip() == "": raise MlflowException(_improper_model_uri_msg(uri)) return ParsedModelUri(alias_parts[0], alias=alias_parts[1]) + else: + # The URI is of the form "models:/" + return ParsedModelUri(parts[0]) def get_model_name_and_version(client, models_uri): - (model_name, model_version, model_stage, model_alias) = _parse_model_uri(models_uri) + (model_id, model_name, model_version, model_stage, model_alias) = _parse_model_uri(models_uri) + if model_id is not None: + return (model_id,) if model_version is not None: return model_name, model_version if model_alias is not None: From 7e62f6b137575756ad92507436a29fe5d4617f12 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:34:29 -0700 Subject: [PATCH 20/62] fix Signed-off-by: dbczumar --- mlflow/models/model.py | 32 +++++++++++++++++++------------- mlflow/pyfunc/__init__.py | 6 +++--- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 3163ceb33fa0d..6c708c0abb83a 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -655,7 +655,7 @@ def from_dict(cls, model_dict): @classmethod def log( cls, - artifact_path, + name, flavor, registered_model_name=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, @@ -669,7 +669,7 @@ def log( active run. Args: - artifact_path: Run relative path identifying the model. + name: The name of the model. flavor: Flavor module to save the model with. The module must have the ``save_model`` function that will persist the model as a valid MLflow model. 
@@ -695,11 +695,25 @@ def log( with TempDir() as tmp: local_path = tmp.path("model") + tracking_uri = _resolve_tracking_uri() + client = mlflow.MlflowClient(tracking_uri) + active_run = mlflow.tracking.fluent.active_run() + model = client.create_model( + experiment_id=mlflow.tracking.fluent._get_experiment_id(), + # TODO: Update model name + name=name, + run_id=active_run.info.run_id if active_run is not None else None, + ) + # NO LONGER START A RUN! # if run_id is None: # run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id mlflow_model = cls( - artifact_path=artifact_path, run_id=run_id, metadata=metadata, resources=resources + artifact_path=model.artifact_location, + model_uuid=model.model_id, + run_id=active_run.info.run_id if active_run is not None else None, + metadata=metadata, + resources=resources, ) flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs) # `save_model` calls `load_model` to infer the model requirements, which may result in @@ -710,7 +724,6 @@ def log( if is_in_databricks_runtime(): _copy_model_metadata_for_uc_sharing(local_path, flavor) - tracking_uri = _resolve_tracking_uri() serving_input = mlflow_model.get_serving_input(local_path) # We check signature presence here as some flavors have a default signature as a # fallback when not provided by user, which is set during flavor's save_model() call. @@ -720,15 +733,7 @@ def log( elif tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks": _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) - # NO LONGER LOG ARTIFACTS. CREATE A MODEL AND FINALIZE IT INSTEAD! - client = mlflow.MlflowClient(tracking_uri) - active_run = mlflow.tracking.fluent.active_run() - model = client.create_model( - experiment_id=mlflow.tracking.fluent._get_experiment_id(), - # TODO: Update model name - name=artifact_path, - run_id=active_run.info.run_id if active_run is not None else None, - ) + # NO LONGER LOG ARTIFACTS TO A RUN. CREATE A MODEL AND FINALIZE IT INSTEAD! 
client.log_model_artifacts(model.model_id, local_path) client.finalize_model(model.model_id, status=ModelStatus.READY) @@ -801,6 +806,7 @@ def log( if registered_model_name is not None: registered_model = mlflow.tracking._model_registry.fluent._register_model( + # TODO: Fix this! f"runs:/{run_id}/{mlflow_model.artifact_path}", registered_model_name, await_registration_for=await_registration_for, diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 2c79912cf47ce..a0ac88e119870 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2633,7 +2633,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( - artifact_path, + name, loader_module=None, data_path=None, code_path=None, # deprecated @@ -2665,7 +2665,7 @@ def log_model( and the parameters for the first workflow: ``python_model``, ``artifacts`` together. Args: - artifact_path: The run-relative artifact path to which to log the Python model. + name: The name of the model. loader_module: The name of the Python module that is used to load the model from ``data_path``. This module must define a method with the prototype ``_load_pyfunc(data_path)``. If not ``None``, this module and its @@ -2852,7 +2852,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: metadata of the logged model. 
""" return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.pyfunc, loader_module=loader_module, data_path=data_path, From b1c52ebf34779f5b98f2d06d152f947f000cd2a5 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:41:24 -0700 Subject: [PATCH 21/62] fixen Signed-off-by: dbczumar --- mlflow/models/model.py | 2 ++ mlflow/pyfunc/__init__.py | 2 ++ mlflow/store/tracking/file_store.py | 2 ++ mlflow/tracking/_tracking_service/client.py | 2 ++ mlflow/tracking/client.py | 5 ++++- 5 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 6c708c0abb83a..7fa634e7faa60 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -662,6 +662,7 @@ def log( metadata=None, run_id=None, resources=None, + model_type: Optional[str] = None, **kwargs, ): """ @@ -703,6 +704,7 @@ def log( # TODO: Update model name name=name, run_id=active_run.info.run_id if active_run is not None else None, + model_type=model_type, ) # NO LONGER START A RUN! 
diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index a0ac88e119870..8fc993bf1e6b6 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2653,6 +2653,7 @@ def log_model( example_no_conversion=None, streamable=None, resources: Optional[Union[str, List[Resource]]] = None, + model_type: Optional[str] = None, ): """ Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow @@ -2873,6 +2874,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: streamable=streamable, resources=resources, infer_code_paths=infer_code_paths, + model_type=model_type, ) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 42971683c294f..56715dfa3cb6c 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -1907,6 +1907,7 @@ def create_model( run_id: Optional[str] = None, tags: Optional[List[ModelTag]] = None, params: Optional[List[ModelParam]] = None, + model_type: Optional[str] = None, ) -> Model: """ Create a new model. 
@@ -1951,6 +1952,7 @@ def create_model( status=ModelStatus.PENDING, tags=tags, params=params, + model_type=model_type, ) # Persist model metadata and create directories for logging metrics, tags diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index d28f54e7db2be..faa9f55540149 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -989,6 +989,7 @@ def create_model( run_id: Optional[str] = None, tags: Optional[Dict[str, str]] = None, params: Optional[Dict[str, str]] = None, + model_type: Optional[str] = None, ) -> Model: return self.store.create_model( experiment_id=experiment_id, @@ -1000,6 +1001,7 @@ def create_model( params=[ModelParam(key, value) for key, value in params.items()] if params is not None else params, + model_type=model_type, ) def finalize_model(self, model_id: str, status: ModelStatus) -> Model: diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index a64d5419d0202..83ed2177ad896 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -4745,8 +4745,11 @@ def create_model( run_id: Optional[str] = None, tags: Optional[Dict[str, str]] = None, params: Optional[Dict[str, str]] = None, + model_type: Optional[str] = None, ) -> Model: - return self._tracking_client.create_model(experiment_id, name, run_id, tags, params) + return self._tracking_client.create_model( + experiment_id, name, run_id, tags, params, model_type + ) def finalize_model(self, model_id: str, status: ModelStatus) -> Model: return self._tracking_client.finalize_model(model_id, status) From 0d7640b2b76f64276a871b31a3f33c2fe08dcc01 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:44:25 -0700 Subject: [PATCH 22/62] fixen Signed-off-by: dbczumar --- mlflow/models/model.py | 4 ++++ mlflow/pyfunc/__init__.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 
7fa634e7faa60..5195df9d5bd58 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -663,6 +663,8 @@ def log( run_id=None, resources=None, model_type: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -705,6 +707,8 @@ def log( name=name, run_id=active_run.info.run_id if active_run is not None else None, model_type=model_type, + params={key: str(value) for key, value in params.items()}, + tags={key: str(value) for key, value in tags.items()}, ) # NO LONGER START A RUN! diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 8fc993bf1e6b6..e1034ed1aa6ff 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2653,6 +2653,8 @@ def log_model( example_no_conversion=None, streamable=None, resources: Optional[Union[str, List[Resource]]] = None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, model_type: Optional[str] = None, ): """ @@ -2875,6 +2877,8 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: resources=resources, infer_code_paths=infer_code_paths, model_type=model_type, + params=params, + tags=tags, ) From af40b62f31f193b3da6c5ff008aa63ed8cd43d28 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 20:45:14 -0700 Subject: [PATCH 23/62] fix Signed-off-by: dbczumar --- mlflow/models/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 5195df9d5bd58..52951fc547de3 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -707,8 +707,10 @@ def log( name=name, run_id=active_run.info.run_id if active_run is not None else None, model_type=model_type, - params={key: str(value) for key, value in params.items()}, - tags={key: str(value) for key, value in tags.items()}, + params={key: str(value) for key, value in params.items()} + if params is not None + else None, + tags={key: 
str(value) for key, value in tags.items()} if tags is not None else None, ) # NO LONGER START A RUN! From 4b270789f1faf93b536be951a4c40d0e672eea23 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:24:57 -0700 Subject: [PATCH 24/62] fix Signed-off-by: dbczumar --- mlflow/models/model.py | 7 ++++++- mlflow/pyfunc/__init__.py | 2 ++ mlflow/tracking/_tracking_service/client.py | 16 ++++++++++++---- mlflow/tracking/client.py | 9 ++++++++- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 52951fc547de3..1155518df4dbf 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -14,7 +14,7 @@ import mlflow from mlflow.artifacts import download_artifacts -from mlflow.entities import ModelStatus +from mlflow.entities import ModelOutput, ModelStatus from mlflow.exceptions import MlflowException from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST @@ -665,6 +665,7 @@ def log( model_type: Optional[str] = None, params: Optional[Dict[str, Any]] = None, tags: Optional[Dict[str, Any]] = None, + step: int = 0, **kwargs, ): """ @@ -712,6 +713,10 @@ def log( else None, tags={key: str(value) for key, value in tags.items()} if tags is not None else None, ) + if active_run is not None: + client.log_outputs( + run_id=active_run.info.run_id, models=[ModelOutput(model.model_id, step=step)] + ) # NO LONGER START A RUN! 
# if run_id is None: diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index e1034ed1aa6ff..80b2cd6b8c4e6 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2656,6 +2656,7 @@ def log_model( params: Optional[Dict[str, Any]] = None, tags: Optional[Dict[str, Any]] = None, model_type: Optional[str] = None, + step: int = 0, ): """ Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow @@ -2879,6 +2880,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: model_type=model_type, params=params, tags=tags, + step=step, ) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index faa9f55540149..b3db59444651c 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -15,6 +15,8 @@ ExperimentTag, Metric, Model, + ModelInput, + ModelOutput, ModelParam, ModelStatus, ModelTag, @@ -757,12 +759,18 @@ def log_batch( # Merge all the run operations into a single run operations object return get_combined_run_operations(run_operations_list) - def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None): + def log_inputs( + self, + run_id: str, + datasets: Optional[List[DatasetInput]] = None, + models: Optional[List[ModelInput]] = None, + ): """Log one or more dataset inputs to a run. Args: run_id: String ID of the run datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log. Raises: MlflowException: If any errors occur. 
@@ -770,10 +778,10 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) Returns: None """ - if datasets is None or len(datasets) == 0: - return + self.store.log_inputs(run_id=run_id, datasets=datasets, models=models) - self.store.log_inputs(run_id=run_id, datasets=datasets) + def log_outputs(self, run_id: str, models: List[ModelOutput]): + self.store.log_outputs(run_id=run_id, models=models) def _record_logged_model(self, run_id, mlflow_model): from mlflow.models import Model diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 83ed2177ad896..6252b567feb47 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -25,6 +25,8 @@ FileInfo, Metric, Model, + ModelInput, + ModelOutput, ModelStatus, Param, Run, @@ -1862,6 +1864,7 @@ def log_inputs( self, run_id: str, datasets: Optional[Sequence[DatasetInput]] = None, + models: Optional[Sequence[ModelInput]] = None, ) -> None: """ Log one or more dataset inputs to a run. @@ -1869,11 +1872,15 @@ def log_inputs( Args: run_id: String ID of the run. datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log. Raises: mlflow.MlflowException: If any errors occur. """ - self._tracking_client.log_inputs(run_id, datasets) + self._tracking_client.log_inputs(run_id, datasets, models) + + def log_outputs(self, run_id: str, models: Sequence[ModelOutput]): + self._tracking_client.log_outputs(run_id, models) def log_artifact(self, run_id, local_path, artifact_path=None) -> None: """Write a local file or directory to the remote ``artifact_uri``. 
From 54a86b47629928a11ecd7daf1565493690853ba7 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:27:47 -0700 Subject: [PATCH 25/62] fix Signed-off-by: dbczumar --- mlflow/langchain/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index 688bba77b2301..57a1f9308b8bf 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -422,6 +422,10 @@ def log_model( run_id=None, model_config=None, streamable=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, ): """ Log a LangChain model as an MLflow artifact for the current run. From 75dfbce9302d5a261105aeee5bd0fa4f1bc46143 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:29:50 -0700 Subject: [PATCH 26/62] fix Signed-off-by: dbczumar --- mlflow/langchain/__init__.py | 4 ++++ mlflow/pyfunc/__init__.py | 2 +- mlflow/pytorch/__init__.py | 8 ++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index 57a1f9308b8bf..d35d8225c0b46 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -569,6 +569,10 @@ def load_retriever(persist_directory): run_id=run_id, model_config=model_config, streamable=streamable, + params=params, + tags=tags, + model_type=model_type, + step=step, ) diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 80b2cd6b8c4e6..4e01542fb5cf8 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2877,9 +2877,9 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: streamable=streamable, resources=resources, infer_code_paths=infer_code_paths, - model_type=model_type, params=params, tags=tags, + model_type=model_type, step=step, ) diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index 198ac1ff6c609..ae7b83241add1 100644 
--- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -150,6 +150,10 @@ def log_model( pip_requirements=None, extra_pip_requirements=None, metadata=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, **kwargs, ): """ @@ -308,6 +312,10 @@ class definition itself, should be included in one of the following locations: pip_requirements=pip_requirements, extra_pip_requirements=extra_pip_requirements, metadata=metadata, + params=params, + tags=tags, + model_type=model_type, + step=step, **kwargs, ) From a8ef24e955c3a77d9b9bbae4238b2e0b01bcc82d Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:32:03 -0700 Subject: [PATCH 27/62] fix Signed-off-by: dbczumar --- mlflow/langchain/__init__.py | 6 +++--- mlflow/pytorch/__init__.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index d35d8225c0b46..c34eaa8322223 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -406,7 +406,7 @@ def load_retriever(persist_directory): @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( lc_model, - artifact_path, + name: str, conda_env=None, code_paths=None, registered_model_name=None, @@ -441,7 +441,7 @@ def log_model( .. Note:: Experimental: Using model as path may change or be removed in a future release without warning. - artifact_path: Run-relative artifact path. + name: The name of the model. conda_env: {{ conda_env }} code_paths: {{ code_paths }} registered_model_name: This argument may change or be removed in a @@ -551,7 +551,7 @@ def load_retriever(persist_directory): metadata of the logged model. 
""" return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.langchain, registered_model_name=registered_model_name, lc_model=lc_model, diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index ae7b83241add1..26df134ab4984 100644 --- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -137,7 +137,7 @@ def get_default_conda_env(): @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="torch")) def log_model( pytorch_model, - artifact_path, + name: str, conda_env=None, code_paths=None, pickle_module=None, @@ -181,7 +181,7 @@ class definition itself, should be included in one of the following locations: ``conda_env`` parameter. - One or more of the files specified by the ``code_paths`` parameter. - artifact_path: Run-relative artifact path. + name: The name of the model. conda_env: {{ conda_env }} code_paths: {{ code_paths }} pickle_module: The module that PyTorch should use to serialize ("pickle") the specified @@ -297,7 +297,7 @@ class definition itself, should be included in one of the following locations: """ pickle_module = pickle_module or mlflow_pytorch_pickle_module return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.pytorch, pytorch_model=pytorch_model, conda_env=conda_env, From 71115663bdffe3d38822f8b4487cf420b044b66b Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:39:58 -0700 Subject: [PATCH 28/62] fix Signed-off-by: dbczumar --- mlflow/models/model.py | 3 ++- mlflow/store/model_registry/file_store.py | 4 ++-- mlflow/tracking/_model_registry/client.py | 5 +++++ mlflow/tracking/_model_registry/fluent.py | 11 ++++++++++- mlflow/tracking/client.py | 2 ++ 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 1155518df4dbf..563933c03226c 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -820,10 +820,11 @@ def log( if registered_model_name is not None: registered_model = 
mlflow.tracking._model_registry.fluent._register_model( # TODO: Fix this! - f"runs:/{run_id}/{mlflow_model.artifact_path}", + f"models:/{model.model_id}", registered_model_name, await_registration_for=await_registration_for, local_model_path=local_path, + model_id=model.model_id, ) return client.get_model_version(registered_model_name, registered_model.version) else: diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index 8f2fc81051696..0ce2256ea70dd 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -632,8 +632,8 @@ def create_model_version( instances associated with this model version. run_link: Link to the run from an MLflow tracking server that generated this model. description: Description of the version. - model_id: The ID of the model (from an Experiment) that is being promoted to a model - version, if applicable. + model_id: The ID of the model (from an Experiment) that is being promoted to a + registered model version, if applicable. Returns: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion` diff --git a/mlflow/tracking/_model_registry/client.py b/mlflow/tracking/_model_registry/client.py index 5cbe391b7debe..d3773bc029077 100644 --- a/mlflow/tracking/_model_registry/client.py +++ b/mlflow/tracking/_model_registry/client.py @@ -4,6 +4,7 @@ exposed in the :py:mod:`mlflow.tracking` module. """ import logging +from typing import Optional from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag from mlflow.exceptions import MlflowException @@ -188,6 +189,7 @@ def create_model_version( description=None, await_creation_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, local_model_path=None, + model_id: Optional[str] = None, ): """Create a new model version from given source. 
@@ -202,6 +204,8 @@ def create_model_version( await_creation_for: Number of seconds to wait for the model version to finish being created and is in ``READY`` status. By default, the function waits for five minutes. Specify 0 or None to skip waiting. + model_id: The ID of the model (from an Experiment) that is being promoted to a + registered model version, if applicable. Returns: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by @@ -220,6 +224,7 @@ def create_model_version( run_link, description, local_model_path=local_model_path, + model_id=model_id, ) else: # Fall back to calling create_model_version without diff --git a/mlflow/tracking/_model_registry/fluent.py b/mlflow/tracking/_model_registry/fluent.py index 50892ab28063c..7bfff8c013bc7 100644 --- a/mlflow/tracking/_model_registry/fluent.py +++ b/mlflow/tracking/_model_registry/fluent.py @@ -20,6 +20,7 @@ def register_model( await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, *, tags: Optional[Dict[str, Any]] = None, + model_id: Optional[str] = None, ) -> ModelVersion: """Create a new model version in model registry for the model files specified by ``model_uri``. @@ -40,6 +41,8 @@ def register_model( waits for five minutes. Specify 0 or None to skip waiting. tags: A dictionary of key-value pairs that are converted into :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects. + model_id: The ID of the model (from an Experiment) that is being promoted to a registered + model version, if applicable. 
Returns: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by @@ -75,7 +78,11 @@ def register_model( Version: 1 """ return _register_model( - model_uri=model_uri, name=name, await_registration_for=await_registration_for, tags=tags + model_uri=model_uri, + name=name, + await_registration_for=await_registration_for, + tags=tags, + model_id=model_id, ) @@ -86,6 +93,7 @@ def _register_model( *, tags: Optional[Dict[str, Any]] = None, local_model_path=None, + model_id: Optional[str] = None, ) -> ModelVersion: client = MlflowClient() try: @@ -116,6 +124,7 @@ def _register_model( tags=tags, await_creation_for=await_registration_for, local_model_path=local_model_path, + model_id=model_id, ) eprint( f"Created version '{create_version_response.version}' of model " diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 6252b567feb47..c5395c627980c 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -3594,6 +3594,7 @@ def _create_model_version( description: Optional[str] = None, await_creation_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS, local_model_path: Optional[str] = None, + model_id: Optional[str] = None, ) -> ModelVersion: tracking_uri = self._tracking_client.tracking_uri if ( @@ -3636,6 +3637,7 @@ def _create_model_version( description=description, await_creation_for=await_creation_for, local_model_path=local_model_path, + model_id=model_id, ) def create_model_version( From a92456cf126dd9b1de6a77ad569fef74b7d9bb9e Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:46:43 -0700 Subject: [PATCH 29/62] fix Signed-off-by: dbczumar --- mlflow/store/model_registry/file_store.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index 0ce2256ea70dd..05ebc1e6d0f9f 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -656,9 +656,18 
@@ def next_version(registered_model_name): if urllib.parse.urlparse(source).scheme == "models": parsed_model_uri = _parse_model_uri(source) try: - storage_location = self.get_model_version_download_uri( - parsed_model_uri.name, parsed_model_uri.version - ) + from mlflow.tracking.client import MlflowClient + + if parsed_model_uri.model_id is not None: + # TODO: Propagate tracking URI to file store directly, rather than relying on + # global URI (individual MlflowClient instances may have different tracking + # URIs) + model = MlflowClient().get_model(parsed_model_uri.model_id) + storage_location = model.artifact_location + else: + storage_location = self.get_model_version_download_uri( + parsed_model_uri.name, parsed_model_uri.version + ) except Exception as e: raise MlflowException( f"Unable to fetch model from model URI source artifact location '{source}'." From e5d035b8e590834a69aae35be7f10213bb1bf309 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 21:59:45 -0700 Subject: [PATCH 30/62] fix Signed-off-by: dbczumar --- mlflow/models/model.py | 1 - mlflow/tracking/_model_registry/fluent.py | 9 +++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 563933c03226c..2611e215f443e 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -824,7 +824,6 @@ def log( registered_model_name, await_registration_for=await_registration_for, local_model_path=local_path, - model_id=model.model_id, ) return client.get_model_version(registered_model_name, registered_model.version) else: diff --git a/mlflow/tracking/_model_registry/fluent.py b/mlflow/tracking/_model_registry/fluent.py index 7bfff8c013bc7..8d5b66c9d2de2 100644 --- a/mlflow/tracking/_model_registry/fluent.py +++ b/mlflow/tracking/_model_registry/fluent.py @@ -4,6 +4,7 @@ from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import ALREADY_EXISTS, RESOURCE_ALREADY_EXISTS, ErrorCode from 
mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository +from mlflow.store.artifact.utils.models import _parse_model_uri from mlflow.store.model_registry import ( SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT, SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, @@ -20,7 +21,6 @@ def register_model( await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, *, tags: Optional[Dict[str, Any]] = None, - model_id: Optional[str] = None, ) -> ModelVersion: """Create a new model version in model registry for the model files specified by ``model_uri``. @@ -41,8 +41,6 @@ def register_model( waits for five minutes. Specify 0 or None to skip waiting. tags: A dictionary of key-value pairs that are converted into :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects. - model_id: The ID of the model (from an Experiment) that is being promoted to a registered - model version, if applicable. Returns: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by @@ -82,7 +80,6 @@ def register_model( name=name, await_registration_for=await_registration_for, tags=tags, - model_id=model_id, ) @@ -93,7 +90,6 @@ def _register_model( *, tags: Optional[Dict[str, Any]] = None, local_model_path=None, - model_id: Optional[str] = None, ) -> ModelVersion: client = MlflowClient() try: @@ -117,6 +113,7 @@ def _register_model( source = RunsArtifactRepository.get_underlying_uri(model_uri) (run_id, _) = RunsArtifactRepository.parse_runs_uri(model_uri) + parsed_model_uri = _parse_model_uri(model_uri) create_version_response = client._create_model_version( name=name, source=source, @@ -124,7 +121,7 @@ def _register_model( tags=tags, await_creation_for=await_registration_for, local_model_path=local_model_path, - model_id=model_id, + model_id=parsed_model_uri.model_id, ) eprint( f"Created version '{create_version_response.version}' of model " From d0184864ac5e453431a6ab3095f77dbab6e5ba5c Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:22:48 -0700 
Subject: [PATCH 31/62] progress Signed-off-by: dbczumar --- mlflow/tracking/_tracking_service/client.py | 11 +++++++++-- mlflow/tracking/client.py | 9 ++++++++- mlflow/tracking/fluent.py | 12 +++++++++++- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index b3db59444651c..7e959428343dd 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -539,7 +539,14 @@ def rename_experiment(self, experiment_id, new_name): self.store.rename_experiment(experiment_id, new_name) def log_metric( - self, run_id, key, value, timestamp=None, step=None, synchronous=True + self, + run_id, + key, + value, + timestamp=None, + step=None, + synchronous=True, + model_id: Optional[str] = None, ) -> Optional[RunOperations]: """Log a metric against the run ID. @@ -566,7 +573,7 @@ def log_metric( timestamp = timestamp if timestamp is not None else get_current_time_millis() step = step if step is not None else 0 metric_value = convert_metric_value_to_float_if_possible(value) - metric = Metric(key, metric_value, timestamp, step) + metric = Metric(key, metric_value, timestamp, step, model_id=model_id) if synchronous: self.store.log_metric(run_id, metric) else: diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index c5395c627980c..267ebbc7edb15 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -1444,6 +1444,7 @@ def log_metric( timestamp: Optional[int] = None, step: Optional[int] = None, synchronous: Optional[bool] = None, + model_id: Optional[str] = None, ) -> Optional[RunOperations]: """ Log a metric against the run ID. 
@@ -1519,7 +1520,13 @@ def print_run_info(r): synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() ) return self._tracking_client.log_metric( - run_id, key, value, timestamp, step, synchronous=synchronous + run_id, + key, + value, + timestamp, + step, + synchronous=synchronous, + model_id=model_id, ) def log_param( diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 257627f316b62..9d8bebefe1b0b 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -18,6 +18,7 @@ Experiment, InputTag, Metric, + ModelInput, Param, Run, RunStatus, @@ -825,6 +826,7 @@ def log_metric( synchronous: Optional[bool] = None, timestamp: Optional[int] = None, run_id: Optional[str] = None, + model_id: Optional[str] = None, ) -> Optional[RunOperations]: """ Log a metric under the current run. If no run is active, this method will create @@ -868,13 +870,21 @@ def log_metric( """ run_id = run_id or _get_or_start_run().info.run_id synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() - return MlflowClient().log_metric( + client = MlflowClient() + if model_id is not None: + run = client.get_run(run_id) + if model_id not in [inp.model_id for inp in run.inputs.model_inputs] + [ + output.model_id for output in run.outputs.model_outputs + ]: + client.log_inputs(run_id, models=[ModelInput(model_id=model_id)]) + return client.log_metric( run_id, key, value, timestamp or get_current_time_millis(), step or 0, synchronous=synchronous, + model_id=model_id, ) From 3abecff82af3824d826f0d0719575eb939779b13 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:31:22 -0700 Subject: [PATCH 32/62] fix Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 9d8bebefe1b0b..486454d0661e3 100644 --- a/mlflow/tracking/fluent.py +++ 
b/mlflow/tracking/fluent.py @@ -870,14 +870,19 @@ def log_metric( """ run_id = run_id or _get_or_start_run().info.run_id synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() - client = MlflowClient() - if model_id is not None: - run = client.get_run(run_id) - if model_id not in [inp.model_id for inp in run.inputs.model_inputs] + [ - output.model_id for output in run.outputs.model_outputs - ]: - client.log_inputs(run_id, models=[ModelInput(model_id=model_id)]) - return client.log_metric( + _log_inputs_for_metrics( + run_id, + [ + Metric( + key=key, + value=value, + timestamp=timestamp or get_current_time_millis(), + step=step or 0, + model_id=model_id, + ), + ], + ) + return MlflowClient().log_metric( run_id, key, value, @@ -888,6 +893,16 @@ def log_metric( ) +def _log_inputs_for_metrics(run_id, metrics: List[Metric]) -> None: + client = MlflowClient() + run = client.get_run(run_id) + for metric in [metric for metric in metrics if metric.model_id is not None]: + if metric.model_id not in [inp.model_id for inp in run.inputs.model_inputs] + [ + output.model_id for output in run.outputs.model_outputs + ]: + client.log_inputs(run_id, models=[ModelInput(model_id=metric.model_id)]) + + def log_metrics( metrics: Dict[str, float], step: Optional[int] = None, From 04430565043c7038f1f9821b23bbaf67ec55caca Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:32:37 -0700 Subject: [PATCH 33/62] prog Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 486454d0661e3..e27d5fc03e407 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -870,7 +870,7 @@ def log_metric( """ run_id = run_id or _get_or_start_run().info.run_id synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() - _log_inputs_for_metrics( + 
_log_inputs_for_metrics_if_necessary( run_id, [ Metric( @@ -893,7 +893,7 @@ def log_metric( ) -def _log_inputs_for_metrics(run_id, metrics: List[Metric]) -> None: +def _log_inputs_for_metrics_if_necessary(run_id, metrics: List[Metric]) -> None: client = MlflowClient() run = client.get_run(run_id) for metric in [metric for metric in metrics if metric.model_id is not None]: @@ -951,6 +951,7 @@ def log_metrics( run_id = run_id or _get_or_start_run().info.run_id timestamp = timestamp or get_current_time_millis() metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + _log_inputs_for_metrics_if_necessary(run_id, metrics_arr) synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() return MlflowClient().log_batch( run_id=run_id, metrics=metrics_arr, params=[], tags=[], synchronous=synchronous From e33d8c56051138ca2cd9c481c4663ea0ab4bf23a Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:51:48 -0700 Subject: [PATCH 34/62] prog Signed-off-by: dbczumar --- mlflow/models/model.py | 31 +++++++++++++++++++-- mlflow/store/tracking/file_store.py | 1 + mlflow/tracking/_tracking_service/client.py | 12 +++++--- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 2611e215f443e..2518f22ba08cc 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -14,7 +14,7 @@ import mlflow from mlflow.artifacts import download_artifacts -from mlflow.entities import ModelOutput, ModelStatus +from mlflow.entities import Metric, ModelOutput, ModelStatus from mlflow.exceptions import MlflowException from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST @@ -695,6 +695,29 @@ def log( metadata of the logged model. 
""" + def log_model_metrics_for_step(client, model_id, run_id, step): + metric_names = client.get_run(run_id).data.metrics.keys() + metrics_for_step = [] + for metric_name in metric_names: + history = client.get_metric_history(run_id, metric_name) + metrics_for_step.extend( + [ + Metric( + key=metric.key, + value=metric.value, + timestamp=metric.timestamp, + step=metric.step, + dataset_name=metric.dataset_name, + dataset_digest=metric.dataset_digest, + run_id=metric.run_id, + model_id=model_id, + ) + for metric in history + if metric.step == step and metric.model_id is None + ] + ) + client.log_batch(run_id=run_id, metrics=metrics_for_step) + registered_model = None with TempDir() as tmp: local_path = tmp.path("model") @@ -714,8 +737,10 @@ def log( tags={key: str(value) for key, value in tags.items()} if tags is not None else None, ) if active_run is not None: - client.log_outputs( - run_id=active_run.info.run_id, models=[ModelOutput(model.model_id, step=step)] + run_id = active_run.info.run_id + client.log_outputs(run_id=run_id, models=[ModelOutput(model.model_id, step=step)]) + log_model_metrics_for_step( + client=client, model_id=model.model_id, run_id=run_id, step=step ) # NO LONGER START A RUN! 
diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 56715dfa3cb6c..82611222889c4 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2017,6 +2017,7 @@ def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: def _make_persisted_model_dict(self, model: Model) -> Dict[str, Any]: model_dict = model.to_dictionary() model_dict.pop("tags", None) + model_dict.pop("metrics", None) return model_dict def _get_model_dict(self, model_id: str) -> Dict[str, Any]: diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 7e959428343dd..20eb338613ad2 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -712,10 +712,14 @@ def log_batch( metrics = [ Metric( - metric.key, - convert_metric_value_to_float_if_possible(metric.value), - metric.timestamp, - metric.step, + key=metric.key, + value=convert_metric_value_to_float_if_possible(metric.value), + timestamp=metric.timestamp, + step=metric.step, + dataset_name=metric.dataset_name, + dataset_digest=metric.dataset_digest, + model_id=metric.model_id, + run_id=metric.run_id, ) for metric in metrics ] From bd8787a3b13d95dcf6a489784e236ad679cc1943 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:56:13 -0700 Subject: [PATCH 35/62] prog Signed-off-by: dbczumar --- mlflow/store/model_registry/file_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index 05ebc1e6d0f9f..3f805369a9832 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -664,6 +664,7 @@ def next_version(registered_model_name): # URIs) model = MlflowClient().get_model(parsed_model_uri.model_id) storage_location = model.artifact_location + run_id = run_id or model.run_id else: storage_location = 
self.get_model_version_download_uri( parsed_model_uri.name, parsed_model_uri.version From 70afd46370cd4eb53b1e07a0e88ace7b0547c76f Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 22:59:40 -0700 Subject: [PATCH 36/62] progress Signed-off-by: dbczumar --- mlflow/models/model.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 2518f22ba08cc..36f005b142b44 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -743,9 +743,6 @@ def log_model_metrics_for_step(client, model_id, run_id, step): client=client, model_id=model.model_id, run_id=run_id, step=step ) - # NO LONGER START A RUN! - # if run_id is None: - # run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id mlflow_model = cls( artifact_path=model.artifact_location, model_uuid=model.model_id, @@ -771,12 +768,9 @@ def log_model_metrics_for_step(client, model_id, run_id, step): elif tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks": _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) - # NO LONGER LOG ARTIFACTS TO A RUN. CREATE A MODEL AND FINALIZE IT INSTEAD! client.log_model_artifacts(model.model_id, local_path) client.finalize_model(model.model_id, status=ModelStatus.READY) - # mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id) - # # if the model_config kwarg is passed in, then log the model config as an params # if model_config := kwargs.get("model_config"): # if isinstance(model_config, str): @@ -844,7 +838,6 @@ def log_model_metrics_for_step(client, model_id, run_id, step): if registered_model_name is not None: registered_model = mlflow.tracking._model_registry.fluent._register_model( - # TODO: Fix this! 
f"models:/{model.model_id}", registered_model_name, await_registration_for=await_registration_for, From f6dc44f996de2653053581d608717ed799da41fb Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 23:03:09 -0700 Subject: [PATCH 37/62] fix Signed-off-by: dbczumar --- mlflow/langchain/__init__.py | 2 ++ mlflow/models/model.py | 29 ++++++++++++++++++----------- mlflow/pyfunc/__init__.py | 2 ++ mlflow/pytorch/__init__.py | 2 ++ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index c34eaa8322223..6334ae0d9ca65 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -426,6 +426,7 @@ def log_model( tags: Optional[Dict[str, Any]] = None, model_type: Optional[str] = None, step: int = 0, + model_id: Optional[str] = None, ): """ Log a LangChain model as an MLflow artifact for the current run. @@ -573,6 +574,7 @@ def load_retriever(persist_directory): tags=tags, model_type=model_type, step=step, + model_id=model_id, ) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 36f005b142b44..d32ed663fb4c3 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -666,6 +666,7 @@ def log( params: Optional[Dict[str, Any]] = None, tags: Optional[Dict[str, Any]] = None, step: int = 0, + model_id: Optional[str] = None, **kwargs, ): """ @@ -725,17 +726,23 @@ def log_model_metrics_for_step(client, model_id, run_id, step): tracking_uri = _resolve_tracking_uri() client = mlflow.MlflowClient(tracking_uri) active_run = mlflow.tracking.fluent.active_run() - model = client.create_model( - experiment_id=mlflow.tracking.fluent._get_experiment_id(), - # TODO: Update model name - name=name, - run_id=active_run.info.run_id if active_run is not None else None, - model_type=model_type, - params={key: str(value) for key, value in params.items()} - if params is not None - else None, - tags={key: str(value) for key, value in tags.items()} if tags is not None else None, - ) 
+ if model_id is None: + model = client.get_model(model_id) + else: + model = client.create_model( + experiment_id=mlflow.tracking.fluent._get_experiment_id(), + # TODO: Update model name + name=name, + run_id=active_run.info.run_id if active_run is not None else None, + model_type=model_type, + params={key: str(value) for key, value in params.items()} + if params is not None + else None, + tags={key: str(value) for key, value in tags.items()} + if tags is not None + else None, + ) + if active_run is not None: run_id = active_run.info.run_id client.log_outputs(run_id=run_id, models=[ModelOutput(model.model_id, step=step)]) diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 4e01542fb5cf8..234622b6e2a65 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2657,6 +2657,7 @@ def log_model( tags: Optional[Dict[str, Any]] = None, model_type: Optional[str] = None, step: int = 0, + model_id: Optional[str] = None, ): """ Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow @@ -2881,6 +2882,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: tags=tags, model_type=model_type, step=step, + model_id=model_id, ) diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index 26df134ab4984..56a04600d69f7 100644 --- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -154,6 +154,7 @@ def log_model( tags: Optional[Dict[str, Any]] = None, model_type: Optional[str] = None, step: int = 0, + model_id: Optional[str] = None, **kwargs, ): """ @@ -316,6 +317,7 @@ class definition itself, should be included in one of the following locations: tags=tags, model_type=model_type, step=step, + model_id=model_id, **kwargs, ) From ec6e9194b7875c2053990c85af06d6c006fd5b59 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Tue, 20 Aug 2024 23:14:24 -0700 Subject: [PATCH 38/62] fix Signed-off-by: dbczumar --- mlflow/__init__.py | 2 ++ mlflow/langchain/__init__.py | 2 +- 
mlflow/models/model.py | 8 +++++++- mlflow/pyfunc/__init__.py | 2 +- mlflow/pytorch/__init__.py | 2 +- mlflow/tracking/fluent.py | 23 +++++++++++++++++++++++ 6 files changed, 35 insertions(+), 4 deletions(-) diff --git a/mlflow/__init__.py b/mlflow/__init__.py index 404fe895eb2a1..41327ecd0221c 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -126,6 +126,7 @@ active_run, autolog, create_experiment, + create_model, delete_experiment, delete_run, delete_tag, @@ -173,6 +174,7 @@ "active_run", "autolog", "create_experiment", + "create_model", "delete_experiment", "delete_run", "delete_tag", diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index 6334ae0d9ca65..23045b1ce97b1 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -406,7 +406,7 @@ def load_retriever(persist_directory): @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( lc_model, - name: str, + name: Optional[str] = None, conda_env=None, code_paths=None, registered_model_name=None, diff --git a/mlflow/models/model.py b/mlflow/models/model.py index d32ed663fb4c3..4a34edee3d653 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -695,6 +695,12 @@ def log( A :py:class:`ModelInfo ` instance that contains the metadata of the logged model. """ + if (model_id, name).count(None) == 2: + raise MlflowException( + "Either `model_id` or `name` must be specified when logging a model. 
" + "Both are None.", + error_code=INVALID_PARAMETER_VALUE, + ) def log_model_metrics_for_step(client, model_id, run_id, step): metric_names = client.get_run(run_id).data.metrics.keys() @@ -726,7 +732,7 @@ def log_model_metrics_for_step(client, model_id, run_id, step): tracking_uri = _resolve_tracking_uri() client = mlflow.MlflowClient(tracking_uri) active_run = mlflow.tracking.fluent.active_run() - if model_id is None: + if model_id is not None: model = client.get_model(model_id) else: model = client.create_model( diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 234622b6e2a65..0f47db32a67ad 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2633,7 +2633,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( - name, + name=None, loader_module=None, data_path=None, code_path=None, # deprecated diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index 56a04600d69f7..02c453745d3d5 100644 --- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -137,7 +137,7 @@ def get_default_conda_env(): @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="torch")) def log_model( pytorch_model, - name: str, + name: Optional[str] = None, conda_env=None, code_paths=None, pickle_module=None, diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index e27d5fc03e407..bda9a75586c62 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -18,6 +18,7 @@ Experiment, InputTag, Metric, + Model, ModelInput, Param, Run, @@ -1841,6 +1842,28 @@ def delete_experiment(experiment_id: str) -> None: MlflowClient().delete_experiment(experiment_id) +def create_model( + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + 
model_type: Optional[str] = None, + experiment_id: Optional[str] = None, +) -> Model: + run = active_run() + if run_id is None and run is not None: + run_id = run.info.run_id + experiment_id = experiment_id if experiment_id is not None else _get_experiment_id() + return MlflowClient().create_model( + experiment_id=experiment_id, + name=name, + run_id=run_id, + tags=tags, + params=params, + model_type=model_type, + ) + + def delete_run(run_id: str) -> None: """ Deletes a run with the given ID. From 4094f0cb4b993a9bd8519dfbc692ac0a35c7b06a Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 01:01:43 -0700 Subject: [PATCH 39/62] fix Signed-off-by: dbczumar --- mlflow/tracing/constant.py | 2 ++ mlflow/tracing/fluent.py | 12 +++++++++--- mlflow/tracing/processor/mlflow.py | 10 ++++++++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlflow/tracing/constant.py b/mlflow/tracing/constant.py index 6fd5c0c025178..68eaa47c20aa9 100644 --- a/mlflow/tracing/constant.py +++ b/mlflow/tracing/constant.py @@ -3,6 +3,7 @@ class TraceMetadataKey: INPUTS = "mlflow.traceInputs" OUTPUTS = "mlflow.traceOutputs" SOURCE_RUN = "mlflow.sourceRun" + MODEL_ID = "mlflow.modelId" class TraceTagKey: @@ -19,6 +20,7 @@ class SpanAttributeKey: OUTPUTS = "mlflow.spanOutputs" SPAN_TYPE = "mlflow.spanType" FUNCTION_NAME = "mlflow.spanFunctionName" + MODEL_ID = "mlflow.modelId" # All storage backends are guaranteed to support key values up to 250 characters diff --git a/mlflow/tracing/fluent.py b/mlflow/tracing/fluent.py index c1c7d66306de8..95bc74ad5d7e5 100644 --- a/mlflow/tracing/fluent.py +++ b/mlflow/tracing/fluent.py @@ -58,6 +58,7 @@ def trace( name: Optional[str] = None, span_type: str = SpanType.UNKNOWN, attributes: Optional[Dict[str, Any]] = None, + model_id: Optional[str] = None, ) -> Callable: """ A decorator that creates a new span for the decorated function. 
@@ -135,7 +136,9 @@ class _WrappingContext: def _wrapping_logic(fn, args, kwargs): span_name = name or fn.__name__ - with start_span(name=span_name, span_type=span_type, attributes=attributes) as span: + with start_span( + name=span_name, span_type=span_type, attributes=attributes, model_id=model_id + ) as span: span.set_attribute(SpanAttributeKey.FUNCTION_NAME, fn.__name__) try: span.set_inputs(capture_function_input_args(fn, args, kwargs)) @@ -184,6 +187,7 @@ def start_span( name: str = "span", span_type: Optional[str] = SpanType.UNKNOWN, attributes: Optional[Dict[str, Any]] = None, + model_id: Optional[str] = None, ) -> Generator[LiveSpan, None, None]: """ Context manager to create a new span and start it as the current span in the context. @@ -253,9 +257,11 @@ def start_span( # Create a new MLflow span and register it to the in-memory trace manager request_id = get_otel_attribute(otel_span, SpanAttributeKey.REQUEST_ID) mlflow_span = create_mlflow_span(otel_span, request_id, span_type) - mlflow_span.set_attributes(attributes or {}) + attributes = dict(attributes) if attributes is not None else {} + if model_id is not None: + attributes[SpanAttributeKey.MODEL_ID] = model_id + mlflow_span.set_attributes(attributes) InMemoryTraceManager.get_instance().register_span(mlflow_span) - except Exception as e: _logger.warning( f"Failed to start span: {e}. 
For full traceback, set logging level to debug.", diff --git a/mlflow/tracing/processor/mlflow.py b/mlflow/tracing/processor/mlflow.py index bed4c48f275b4..21decbda90ffa 100644 --- a/mlflow/tracing/processor/mlflow.py +++ b/mlflow/tracing/processor/mlflow.py @@ -145,17 +145,21 @@ def on_end(self, span: OTelReadableSpan) -> None: return request_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID) + # TODO: We should remove the model ID from the span attributes + model_id = get_otel_attribute(span, SpanAttributeKey.MODEL_ID) with self._trace_manager.get_trace(request_id) as trace: if trace is None: _logger.debug(f"Trace data with request ID {request_id} not found.") return - self._update_trace_info(trace, span) + self._update_trace_info(trace, span, model_id) deduplicate_span_names_in_place(list(trace.span_dict.values())) super().on_end(span) - def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan): + def _update_trace_info( + self, trace: _Trace, root_span: OTelReadableSpan, model_id: Optional[str] + ): """Update the trace info with the final values from the root span.""" # The trace/span start time needs adjustment to exclude the latency of # the backend API call. 
We already adjusted the span start time in the @@ -173,6 +177,8 @@ def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan): ), } ) + if model_id is not None: + trace.info.request_metadata[SpanAttributeKey.MODEL_ID] = model_id def _truncate_metadata(self, value: Optional[str]) -> str: """Get truncated value of the attribute if it exceeds the maximum length.""" From 7ad8c3a6964f4d8093c236357b9bb690771c20c0 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 01:09:45 -0700 Subject: [PATCH 40/62] fix Signed-off-by: dbczumar --- mlflow/tracing/fluent.py | 2 ++ mlflow/tracking/_tracking_service/client.py | 13 +++++++++++++ mlflow/tracking/client.py | 2 ++ 3 files changed, 17 insertions(+) diff --git a/mlflow/tracing/fluent.py b/mlflow/tracing/fluent.py index 95bc74ad5d7e5..dc0932acd05d1 100644 --- a/mlflow/tracing/fluent.py +++ b/mlflow/tracing/fluent.py @@ -338,6 +338,7 @@ def search_traces( max_results: Optional[int] = None, order_by: Optional[List[str]] = None, extract_fields: Optional[List[str]] = None, + model_id: Optional[str] = None, ) -> "pandas.DataFrame": """ Return traces that match the given list of search expressions within the experiments. 
@@ -436,6 +437,7 @@ def pagination_wrapper_func(number_to_get, next_page_token): filter_string=filter_string, order_by=order_by, page_token=next_page_token, + model_id=model_id, ) results = get_results_from_paginated_fn( diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 20eb338613ad2..4b076954dd274 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -303,7 +303,18 @@ def _search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ): + if model_id is not None: + if filter_string: + raise MlflowException( + message=( + "Cannot specify both `model_id` and `experiment_ids` or `filter_string`" + " in the search_traces call." + ), + error_code=INVALID_PARAMETER_VALUE, + ) + filter_string = f"request_metadata.`mlflow.modelId` = '{model_id}'" return self.store.search_traces( experiment_ids=experiment_ids, filter_string=filter_string, @@ -319,6 +330,7 @@ def search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ) -> PagedList[Trace]: def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]: """ @@ -350,6 +362,7 @@ def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]: max_results=next_max_results, order_by=order_by, page_token=next_token, + model_id=model_id, ) traces.extend(t for t in executor.map(download_trace_data, trace_infos) if t) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 267ebbc7edb15..cd4b11d4971f6 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -489,6 +489,7 @@ def search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ) 
-> PagedList[Trace]: """ Return traces that match the given list of search expressions within the experiments. @@ -515,6 +516,7 @@ def search_traces( max_results=max_results, order_by=order_by, page_token=page_token, + model_id=model_id, ) get_display_handler().display_traces(traces) From 1420bf6daa3c00caa559771dd784657cffe750b2 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 01:32:07 -0700 Subject: [PATCH 41/62] client Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 68 ++++++++++++++++++++- mlflow/tracking/_tracking_service/client.py | 9 +++ mlflow/tracking/client.py | 13 +++- 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 82611222889c4..bd9672bb44b0f 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2031,8 +2031,6 @@ def _get_model_dict(self, model_id: str) -> Dict[str, Any]: raise MlflowException( f"Model '{model_id}' metadata is in invalid state.", databricks_pb2.INVALID_STATE ) - model_dict["tags"] = self._get_all_model_tags(model_dir) - model_dict["metrics"] = self._get_all_model_metrics(model_id=model_id, model_dir=model_dir) return model_dict def _get_model_dir(self, experiment_id: str, model_id: str) -> str: @@ -2055,8 +2053,16 @@ def _find_model_root(self, model_id): return os.path.basename(os.path.dirname(os.path.abspath(models_dir_path))), models[0] return None, None + def _get_model_from_dir(self, model_dir: str) -> Model: + return Model.from_dictionary(self._get_model_info_from_dir(model_dir)) + def _get_model_info_from_dir(self, model_dir: str) -> Dict[str, Any]: - return FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME) + model_dict = FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME) + model_dict["tags"] = self._get_all_model_tags(model_dir) + model_dict["metrics"] = self._get_all_model_metrics( + model_id=model_dict["model_id"], model_dir=model_dir + ) 
+ return model_dict def _get_all_model_tags(self, model_dir: str) -> List[ModelTag]: parent_path, tag_files = self._get_resource_files(model_dir, FileStore.TAGS_FOLDER_NAME) @@ -2119,3 +2125,59 @@ def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: st dataset_digest=dataset_digest, run_id=run_id, ) + + def search_models( + self, + experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ) -> List[Model]: + all_models = [] + for experiment_id in experiment_ids: + models = self._list_models(experiment_id) + all_models.extend(models) + # filtered = SearchUtils.filter(runs, filter_string) + # sorted_runs = SearchUtils.sort(filtered, order_by) + # runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results) + return all_models + + def _list_models(self, experiment_id: str) -> List[Model]: + self._check_root_dir() + if not self._has_experiment(experiment_id): + return [] + experiment_dir = self._get_experiment_path(experiment_id, assert_exists=True) + model_dirs = list_all( + os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME), + filter_func=lambda x: all( + os.path.basename(os.path.normpath(x)) != reservedFolderName + for reservedFolderName in FileStore.RESERVED_EXPERIMENT_FOLDERS + ) + and os.path.isdir(x), + full_path=True, + ) + models = [] + for m_dir in model_dirs: + try: + # trap and warn known issues, will raise unexpected exceptions to caller + model = self._get_model_from_dir(m_dir) + if model.experiment_id != experiment_id: + logging.warning( + "Wrong experiment ID (%s) recorded for model '%s'. " + "It should be %s. 
Model will be ignored.", + str(model.experiment_id), + str(model.model_id), + str(experiment_id), + exc_info=True, + ) + continue + models.append(model) + except MissingConfigException as exc: + # trap malformed model exception and log + # this is at debug level because if the same store is used for + # artifact storage, it's common the folder is not a run folder + m_id = os.path.basename(m_dir) + logging.debug( + "Malformed model '%s'. Detailed error %s", m_id, str(exc), exc_info=True + ) + return models diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 4b076954dd274..7119e470f5920 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -1048,6 +1048,15 @@ def set_model_tag(self, model_id: str, key: str, value: str): def log_model_artifacts(self, model_id: str, local_dir: str) -> None: self._get_artifact_repo_for_model(model_id).log_artifacts(local_dir) + def search_models( + self, + experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ): + return self.store.search_models(experiment_ids, filter_string, max_results, order_by) + def _get_artifact_repo_for_model(self, model_id: str) -> ArtifactRepository: # Attempt to fetch the artifact repo from a local cache cached_repo = utils._artifact_repos_cache.get(model_id) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index cd4b11d4971f6..04a38b03766e2 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -4779,4 +4779,15 @@ def set_model_tag(self, model_id: str, key: str, value: str): return self._tracking_client.set_model_tag(model_id, key, value) def log_model_artifacts(self, model_id: str, local_dir: str) -> None: - self._tracking_client.log_model_artifacts(model_id, local_dir) + return self._tracking_client.log_model_artifacts(model_id, local_dir) + + def search_models( + self, 
+ experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ): + return self._tracking_client.search_models( + experiment_ids, filter_string, max_results, order_by + ) From f81d55d7d89705f07c7e888d77abd938b51df9e4 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 01:59:33 -0700 Subject: [PATCH 42/62] search Signed-off-by: dbczumar --- mlflow/entities/model.py | 4 +- mlflow/store/tracking/file_store.py | 6 +- mlflow/tracking/_tracking_service/client.py | 4 +- mlflow/utils/search_utils.py | 102 +++++++++++++++++++- 4 files changed, 106 insertions(+), 10 deletions(-) diff --git a/mlflow/entities/model.py b/mlflow/entities/model.py index 7e0fc16f1c0b1..6f7c7c7f55736 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/model.py @@ -40,12 +40,12 @@ def __init__( self._status: ModelStatus = status self._status_message: Optional[str] = status_message self._tags: Dict[str, str] = ( - {tag.key: tag.value for tag in (tags or [])} if isinstance(tags, list) else tags + {tag.key: tag.value for tag in (tags or [])} if isinstance(tags, list) else (tags or {}) ) self._params: Dict[str, str] = ( {param.key: param.value for param in (params or [])} if isinstance(params, list) - else params + else (params or {}) ) self._metrics: Optional[List[Metric]] = metrics diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index bd9672bb44b0f..a9fcc0e3180be 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2137,10 +2137,8 @@ def search_models( for experiment_id in experiment_ids: models = self._list_models(experiment_id) all_models.extend(models) - # filtered = SearchUtils.filter(runs, filter_string) - # sorted_runs = SearchUtils.sort(filtered, order_by) - # runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results) - return all_models + filtered = SearchUtils.filter_models(models, 
filter_string) + return SearchUtils.sort_models(filtered, order_by)[:max_results] def _list_models(self, experiment_id: str) -> List[Model]: self._check_root_dir() diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 7119e470f5920..df8ed29915e89 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -1027,10 +1027,10 @@ def create_model( experiment_id=experiment_id, name=name, run_id=run_id, - tags=[ModelTag(key, value) for key, value in tags.items()] + tags=[ModelTag(str(key), str(value)) for key, value in tags.items()] if tags is not None else tags, - params=[ModelParam(key, value) for key, value in params.items()] + params=[ModelParam(str(key), str(value)) for key, value in params.items()] if params is not None else params, model_type=model_type, diff --git a/mlflow/utils/search_utils.py b/mlflow/utils/search_utils.py index aa399f93f0e99..6d3b6e0424562 100644 --- a/mlflow/utils/search_utils.py +++ b/mlflow/utils/search_utils.py @@ -5,7 +5,7 @@ import operator import re import shlex -from typing import Any, Dict +from typing import Any, Dict, List, Optional import sqlparse from packaging.version import Version @@ -20,7 +20,7 @@ ) from sqlparse.tokens import Token as TokenType -from mlflow.entities import RunInfo +from mlflow.entities import Model, RunInfo from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -635,6 +635,38 @@ def _does_run_match_clause(cls, run, sed): return SearchUtils.get_comparison_func(comparator)(lhs, value) + @classmethod + def _does_model_match_clause(cls, model, sed): + key_type = sed.get("type") + key = sed.get("key") + value = sed.get("value") + comparator = sed.get("comparator").upper() + + key = SearchUtils.translate_key_alias(key) + + if cls.is_metric(key_type, comparator): + 
matching_metrics = [metric for metric in model.metrics if metric.key == key]
+            lhs = matching_metrics[0].value if matching_metrics else None
+            value = float(value)
+        elif cls.is_param(key_type, comparator):
+            lhs = model.params.get(key, None)
+        elif cls.is_tag(key_type, comparator):
+            lhs = model.tags.get(key, None)
+        elif cls.is_string_attribute(key_type, key, comparator):
+            lhs = getattr(model.info, key)
+        elif cls.is_numeric_attribute(key_type, key, comparator):
+            lhs = getattr(model.info, key)
+            value = int(value)
+        else:
+            raise MlflowException(
+                f"Invalid model search expression type '{key_type}'",
+                error_code=INVALID_PARAMETER_VALUE,
+            )
+        if lhs is None:
+            return False
+
+        return SearchUtils.get_comparison_func(comparator)(lhs, value)
+
     @classmethod
     def filter(cls, runs, filter_string):
         """Filters a set of runs based on a search filter string."""
@@ -647,6 +679,20 @@ def run_matches(run):
 
         return [run for run in runs if run_matches(run)]
 
+    @classmethod
+    def filter_models(cls, models: List[Model], filter_string: Optional[str] = None):
+        """Filters a set of models based on a search filter string."""
+        if not filter_string:
+            return models
+
+        # TODO: Update parsing function to handle model-specific filter clauses
+        parsed = cls.parse_search_filter(filter_string)
+
+        def model_matches(model):
+            return all(cls._does_model_match_clause(model, s) for s in parsed)
+
+        return [model for model in models if model_matches(model)]
+
     @classmethod
     def _validate_order_by_and_generate_token(cls, order_by):
         try:
@@ -760,6 +806,40 @@ def _get_value_for_sort(cls, run, key_type, key, ascending):
 
         return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value)
 
+    @classmethod
+    def _get_model_value_for_sort(cls, model, key_type, key, ascending):
+        """Returns a tuple suitable to be used as a sort key for models."""
+        sort_value = None
+        key = SearchUtils.translate_key_alias(key)
+        if key_type == cls._METRIC_IDENTIFIER:
+            matching_metrics = [metric for metric 
in model.metrics if metric.key == key] + sort_value = float(matching_metrics[0].value) if matching_metrics else None + elif key_type == cls._PARAM_IDENTIFIER: + sort_value = model.params.get(key) + elif key_type == cls._TAG_IDENTIFIER: + sort_value = model.tags.get(key) + elif key_type == cls._ATTRIBUTE_IDENTIFIER: + sort_value = getattr(model, key) + else: + raise MlflowException( + f"Invalid models order_by entity type '{key_type}'", + error_code=INVALID_PARAMETER_VALUE, + ) + + # Return a key such that None values are always at the end. + is_none = sort_value is None + is_nan = isinstance(sort_value, float) and math.isnan(sort_value) + fill_value = (1 if ascending else -1) * math.inf + + if is_none: + sort_value = fill_value + elif is_nan: + sort_value = -fill_value + + is_none_or_nan = is_none or is_nan + + return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value) + @classmethod def sort(cls, runs, order_by_list): """Sorts a set of runs based on their natural ordering and an overriding set of order_bys. @@ -780,6 +860,24 @@ def sort(cls, runs, order_by_list): ) return runs + @classmethod + def sort_models(cls, models, order_by_list): + models = sorted(models, key=lambda model: (-model.creation_timestamp, model.model_id)) + if not order_by_list: + return models + # NB: We rely on the stability of Python's sort function, so that we can apply + # the ordering conditions in reverse order. 
+ for order_by_clause in reversed(order_by_list): + # TODO: Update parsing function to handle model-specific order-by keys + (key_type, key, ascending) = cls.parse_order_by_for_search_runs(order_by_clause) + + models = sorted( + models, + key=lambda model: cls._get_model_value_for_sort(model, key_type, key, ascending), + reverse=not ascending, + ) + return models + @classmethod def parse_start_offset_from_page_token(cls, page_token): # Note: the page_token is expected to be a base64-encoded JSON that looks like From 679292e590a5dc5ab18edea110fc1a7b20dd943e Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:06:23 -0700 Subject: [PATCH 43/62] fluent search Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index bda9a75586c62..b358a587c8864 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -1864,6 +1864,34 @@ def create_model( ) +def search_models( + experiment_ids: Optional[List[str]] = None, + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + output_format: str = "pandas", +) -> Union[List[Model], "pandas.DataFrame"]: + experiment_ids = experiment_ids or [_get_experiment_id()] + models = MlflowClient().search_models( + experiment_ids=experiment_ids, + filter_string=filter_string, + max_results=max_results, + order_by=order_by, + ) + if output_format == "pandas": + import pandas as pd + + return pd.DataFrame([model.to_dictionary() for model in models]) + elif output_format == "list": + return models + else: + raise MlflowException( + "Unsupported output format: %s. Supported string values are 'pandas' or 'list'" + % output_format, + INVALID_PARAMETER_VALUE, + ) + + def delete_run(run_id: str) -> None: """ Deletes a run with the given ID. 
From 441fc93b7e44c248f89d1abf3659204f2fcd7219 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:08:13 -0700 Subject: [PATCH 44/62] fluent search Signed-off-by: dbczumar --- mlflow/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow/__init__.py b/mlflow/__init__.py index 41327ecd0221c..9bf20231b8923 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -154,6 +154,7 @@ log_table, log_text, search_experiments, + search_models, search_runs, set_experiment, set_experiment_tag, @@ -214,6 +215,7 @@ "register_model", "run", "search_experiments", + "search_models", "search_model_versions", "search_registered_models", "search_runs", From 975515e3c8b1254d6cd8d657e57208bb754bd247 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:13:30 -0700 Subject: [PATCH 45/62] proggy Signed-off-by: dbczumar --- mlflow/tracking/_tracking_service/client.py | 12 +++++++++++- mlflow/tracking/client.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index df8ed29915e89..9640269af49f4 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -559,6 +559,8 @@ def log_metric( timestamp=None, step=None, synchronous=True, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, model_id: Optional[str] = None, ) -> Optional[RunOperations]: """Log a metric against the run ID. 
@@ -586,7 +588,15 @@ def log_metric( timestamp = timestamp if timestamp is not None else get_current_time_millis() step = step if step is not None else 0 metric_value = convert_metric_value_to_float_if_possible(value) - metric = Metric(key, metric_value, timestamp, step, model_id=model_id) + metric = Metric( + key, + metric_value, + timestamp, + step, + model_id=model_id, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) if synchronous: self.store.log_metric(run_id, metric) else: diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 04a38b03766e2..3679c4f3b7c62 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -1446,6 +1446,8 @@ def log_metric( timestamp: Optional[int] = None, step: Optional[int] = None, synchronous: Optional[bool] = None, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, model_id: Optional[str] = None, ) -> Optional[RunOperations]: """ @@ -1528,6 +1530,8 @@ def print_run_info(r): timestamp, step, synchronous=synchronous, + dataset_name=dataset_name, + dataset_digest=dataset_digest, model_id=model_id, ) From 35ccde85d1b7a4f5e9f89584be5dcd474d3b1798 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:16:32 -0700 Subject: [PATCH 46/62] proggy Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index b358a587c8864..7630f6edb5544 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -828,6 +828,7 @@ def log_metric( timestamp: Optional[int] = None, run_id: Optional[str] = None, model_id: Optional[str] = None, + dataset: Optional[Dataset] = None, ) -> Optional[RunOperations]: """ Log a metric under the current run. 
If no run is active, this method will create @@ -891,6 +892,8 @@ def log_metric( step or 0, synchronous=synchronous, model_id=model_id, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, ) From b9420ba7aced3629acfea4225a4fe9cfa66e2f90 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:32:07 -0700 Subject: [PATCH 47/62] fix Signed-off-by: dbczumar --- mlflow/sklearn/__init__.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py index 39b3cf7f5cc5b..7bc50a5f75b22 100644 --- a/mlflow/sklearn/__init__.py +++ b/mlflow/sklearn/__init__.py @@ -333,7 +333,7 @@ def save_model( @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) def log_model( sk_model, - artifact_path, + name: Optional[str] = None, conda_env=None, code_paths=None, serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, @@ -345,6 +345,11 @@ def log_model( extra_pip_requirements=None, pyfunc_predict_fn="predict", metadata=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, + model_id: Optional[str] = None, ): """ Log a scikit-learn model as an MLflow artifact for the current run. Produces an MLflow Model @@ -356,7 +361,7 @@ def log_model( Args: sk_model: scikit-learn model to be saved. - artifact_path: Run-relative artifact path. + name: Model name. conda_env: {{ conda_env }} code_paths: {{ code_paths }} serialization_format: The format in which to serialize the model. 
This should be one of @@ -410,7 +415,7 @@ def log_model( """ return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.sklearn, sk_model=sk_model, conda_env=conda_env, @@ -424,6 +429,11 @@ def log_model( extra_pip_requirements=extra_pip_requirements, pyfunc_predict_fn=pyfunc_predict_fn, metadata=metadata, + params=params, + tags=tags, + model_type=model_type, + step=step, + model_id=model_id, ) From 3938d9beccf6d3e8ee39d1ac09869784e9c22c36 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:49:51 -0700 Subject: [PATCH 48/62] fdataset Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 7630f6edb5544..8d58d445469ff 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -913,6 +913,7 @@ def log_metrics( synchronous: Optional[bool] = None, run_id: Optional[str] = None, timestamp: Optional[int] = None, + dataset: Optional[Dataset] = None, ) -> Optional[RunOperations]: """ Log multiple metrics for the current run. 
If no run is active, this method will create a new @@ -958,7 +959,13 @@ def log_metrics( _log_inputs_for_metrics_if_necessary(run_id, metrics_arr) synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() return MlflowClient().log_batch( - run_id=run_id, metrics=metrics_arr, params=[], tags=[], synchronous=synchronous + run_id=run_id, + metrics=metrics_arr, + params=[], + tags=[], + synchronous=synchronous, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, ) From 90b177845d7fdffbb04fa48810d72930da282969 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 02:51:38 -0700 Subject: [PATCH 49/62] fix Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 8d58d445469ff..749b50437a522 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -955,7 +955,19 @@ def log_metrics( """ run_id = run_id or _get_or_start_run().info.run_id timestamp = timestamp or get_current_time_millis() - metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + dataset_name = dataset.name if dataset is not None else None + dataset_digest = dataset.digest if dataset is not None else None + metrics_arr = [ + Metric( + key, + value, + timestamp, + step or 0, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) + for key, value in metrics.items() + ] _log_inputs_for_metrics_if_necessary(run_id, metrics_arr) synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() return MlflowClient().log_batch( @@ -964,8 +976,6 @@ def log_metrics( params=[], tags=[], synchronous=synchronous, - dataset_name=dataset.name if dataset is not None else None, - dataset_digest=dataset.digest if dataset is not None else None, ) From 
6608a4a5a6d71602a6d2e75cb77d1c26857882d3 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:05:17 -0700 Subject: [PATCH 50/62] Model ID Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 749b50437a522..cdc77c6f42cd7 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -913,6 +913,7 @@ def log_metrics( synchronous: Optional[bool] = None, run_id: Optional[str] = None, timestamp: Optional[int] = None, + model_id: Optional[str] = None, dataset: Optional[Dataset] = None, ) -> Optional[RunOperations]: """ @@ -963,6 +964,7 @@ def log_metrics( value, timestamp, step or 0, + model_id=model_id, dataset_name=dataset_name, dataset_digest=dataset_digest, ) From a74af6f388ff8c31123630d59e22d4bfff4e3342 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:08:46 -0700 Subject: [PATCH 51/62] fix Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index cdc77c6f42cd7..a171683095f5e 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -1886,6 +1886,10 @@ def create_model( ) +def get_model(model_id: str) -> Model: + return MlflowClient().get_model(model_id) + + def search_models( experiment_ids: Optional[List[str]] = None, filter_string: Optional[str] = None, From f641549f9bfc359913abfaaf24c4bbbeddae3bfc Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:11:30 -0700 Subject: [PATCH 52/62] get model Signed-off-by: dbczumar --- mlflow/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow/__init__.py b/mlflow/__init__.py index 9bf20231b8923..88e7f09efe60f 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -137,6 +137,7 @@ get_artifact_uri, get_experiment, get_experiment_by_name, + get_model, get_parent_run, get_run, last_active_run, @@ -191,6 +192,7 @@ 
"get_experiment", "get_experiment_by_name", "get_last_active_trace", + "get_model", "get_parent_run", "get_registry_uri", "get_run", From 10a0c690cfbd585e1b5198c9b6527f7c7d351240 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:19:19 -0700 Subject: [PATCH 53/62] fix Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index a9fcc0e3180be..71b4354ad7d74 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -5,6 +5,7 @@ import sys import time import uuid +from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Tuple @@ -2077,15 +2078,17 @@ def _get_all_model_metrics(self, model_id: str, model_dir: str) -> List[Metric]: ) metrics = [] for metric_file in metric_files: - metrics.append( - FileStore._get_model_metric_from_file( + metrics.extend( + FileStore._get_model_metrics_from_file( model_id=model_id, parent_path=parent_path, metric_name=metric_file ) ) return metrics @staticmethod - def _get_model_metric_from_file(model_id: str, parent_path: str, metric_name: str) -> Metric: + def _get_model_metrics_from_file( + model_id: str, parent_path: str, metric_name: str + ) -> List[Metric]: _validate_metric_name(metric_name) metric_objs = [ FileStore._get_model_metric_from_line(model_id, metric_name, line) @@ -2093,11 +2096,18 @@ def _get_model_metric_from_file(model_id: str, parent_path: str, metric_name: st ] if len(metric_objs) == 0: raise ValueError(f"Metric '{metric_name}' is malformed. No data found.") - # Python performs element-wise comparison of equal-length tuples, ordering them - # based on their first differing element. Therefore, we use max() operator to find the - # largest value at the largest timestamp. 
For more information, see - # https://docs.python.org/3/reference/expressions.html#value-comparisons - return max(metric_objs, key=lambda m: (m.step, m.timestamp, m.value)) + + # Group metrics by (dataset_name, dataset_digest) + grouped_metrics = defaultdict(list) + for metric in metric_objs: + key = (metric.dataset_name, metric.dataset_digest) + grouped_metrics[key].append(metric) + + # Compute the max for each group + return [ + max(group, key=lambda m: (m.step, m.timestamp, m.value)) + for group in grouped_metrics.values() + ] @staticmethod def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: str) -> Metric: From e4fd6b4d5010f9f4d4af67c1c6d950b2f8ca1498 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:41:13 -0700 Subject: [PATCH 54/62] all exp Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 71b4354ad7d74..1ad5364b1414f 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2046,8 +2046,10 @@ def _get_model_dir(self, experiment_id: str, model_id: str) -> str: def _find_model_root(self, model_id): self._check_root_dir() all_experiments = self._get_active_experiments(True) + self._get_deleted_experiments(True) + print("ALL EXPERIMENTS", all_experiments) for experiment_dir in all_experiments: models_dir_path = os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME) + print("MODELS DIR PATH", models_dir_path) models = find(models_dir_path, model_id, full_path=True) if len(models) == 0: continue From d8470baa43a4103e84675333d98c467d27e6094c Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:46:45 -0700 Subject: [PATCH 55/62] fix Signed-off-by: dbczumar --- mlflow/store/tracking/file_store.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlflow/store/tracking/file_store.py 
b/mlflow/store/tracking/file_store.py index 1ad5364b1414f..59b2829eb0d26 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -2045,11 +2045,11 @@ def _get_model_dir(self, experiment_id: str, model_id: str) -> str: def _find_model_root(self, model_id): self._check_root_dir() - all_experiments = self._get_active_experiments(True) + self._get_deleted_experiments(True) - print("ALL EXPERIMENTS", all_experiments) + all_experiments = self._get_active_experiments(False) + self._get_deleted_experiments(False) for experiment_dir in all_experiments: - models_dir_path = os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME) - print("MODELS DIR PATH", models_dir_path) + models_dir_path = os.path.join( + self.root_directory, experiment_dir, FileStore.MODELS_FOLDER_NAME + ) models = find(models_dir_path, model_id, full_path=True) if len(models) == 0: continue From ef157c51a69b14562fc9ef191f948f3a9f26d0f2 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 03:56:22 -0700 Subject: [PATCH 56/62] fix Signed-off-by: dbczumar --- mlflow/store/model_registry/file_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index 3f805369a9832..3b57a9a2ff8a1 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -705,7 +705,7 @@ def next_version(registered_model_name): if tags is not None: for tag in tags: self.set_model_version_tag(name, version, tag) - return model_version.to_mlflow_entity() + return self.get_model_version(name, version) except Exception as e: more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1 logging.warning( From 99c79884f70ca492a344cd3b1da237ed1e1401ad Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 04:12:01 -0700 Subject: [PATCH 57/62] fix Signed-off-by: dbczumar --- mlflow/tracing/display/display_handler.py | 4 ++++ 1 file changed, 
4 insertions(+) diff --git a/mlflow/tracing/display/display_handler.py b/mlflow/tracing/display/display_handler.py index f8bff89122048..5da9553d2b618 100644 --- a/mlflow/tracing/display/display_handler.py +++ b/mlflow/tracing/display/display_handler.py @@ -90,6 +90,10 @@ def get_mimebundle(self, traces: List[Trace]): } def display_traces(self, traces: List[Trace]): + # Temporarily disable rendering of traces in Databricks notebooks, + # since it doesnt' work with file-based storage + return + # This only works in Databricks notebooks if not is_in_databricks_runtime(): return From 9b0fe2accc278527a5de4b40a8ac601c59ee0195 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 14:24:26 -0700 Subject: [PATCH 58/62] fixed Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index a171683095f5e..d6fee586a034a 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -881,8 +881,11 @@ def log_metric( timestamp=timestamp or get_current_time_millis(), step=step or 0, model_id=model_id, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, ), ], + datasets=[dataset] if dataset is not None else None, ) return MlflowClient().log_metric( run_id, @@ -897,14 +900,34 @@ def log_metric( ) -def _log_inputs_for_metrics_if_necessary(run_id, metrics: List[Metric]) -> None: +def _log_inputs_for_metrics_if_necessary( + run_id, metrics: List[Metric], datasets: Optional[List[Dataset]] = None +) -> None: client = MlflowClient() run = client.get_run(run_id) - for metric in [metric for metric in metrics if metric.model_id is not None]: - if metric.model_id not in [inp.model_id for inp in run.inputs.model_inputs] + [ - output.model_id for output in run.outputs.model_outputs - ]: + datasets = datasets or [] + for metric in metrics: + if 
metric.model_id is not None and metric.model_id not in [ + inp.model_id for inp in run.inputs.model_inputs + ] + [output.model_id for output in run.outputs.model_outputs]: client.log_inputs(run_id, models=[ModelInput(model_id=metric.model_id)]) + if (metric.dataset_name, metric.dataset_digest) not in [ + (inp.dataset.name, inp.dataset.digest) for inp in run.inputs.dataset_inputs + ]: + matching_dataset = next( + ( + dataset + for dataset in datasets + if dataset.name == metric.dataset_name + and dataset.digest == metric.dataset_digest + ), + None, + ) + if matching_dataset is not None: + client.log_inputs( + run_id, + datasets=[DatasetInput(matching_dataset._to_mlflow_entity(), tags=[])], + ) def log_metrics( @@ -970,7 +993,9 @@ def log_metrics( ) for key, value in metrics.items() ] - _log_inputs_for_metrics_if_necessary(run_id, metrics_arr) + _log_inputs_for_metrics_if_necessary( + run_id, metrics_arr, [dataset] if dataset is not None else None + ) synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() return MlflowClient().log_batch( run_id=run_id, From 39242ed9e723e40d000f07ac414dfcac80c88d77 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 14:31:46 -0700 Subject: [PATCH 59/62] progress Signed-off-by: dbczumar --- mlflow/models/model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 4a34edee3d653..355a9938ebaf1 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -735,15 +735,21 @@ def log_model_metrics_for_step(client, model_id, run_id, step): if model_id is not None: model = client.get_model(model_id) else: + params = { + **(params or {}), + **( + client.get_run(active_run.info.run_id).data.params + if active_run is not None + else {} + ), + } model = client.create_model( experiment_id=mlflow.tracking.fluent._get_experiment_id(), # TODO: Update model name name=name, run_id=active_run.info.run_id if 
active_run is not None else None, model_type=model_type, - params={key: str(value) for key, value in params.items()} - if params is not None - else None, + params={key: str(value) for key, value in params.items()}, tags={key: str(value) for key, value in tags.items()} if tags is not None else None, From 022189d5598ee15b50003af32393e4b7bca8d9b7 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Wed, 21 Aug 2024 14:49:33 -0700 Subject: [PATCH 60/62] fix Signed-off-by: dbczumar --- mlflow/tracking/fluent.py | 42 +++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index d6fee586a034a..4561808d0d5c9 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -887,17 +887,25 @@ def log_metric( ], datasets=[dataset] if dataset is not None else None, ) - return MlflowClient().log_metric( - run_id, - key, - value, - timestamp or get_current_time_millis(), - step or 0, - synchronous=synchronous, - model_id=model_id, - dataset_name=dataset.name if dataset is not None else None, - dataset_digest=dataset.digest if dataset is not None else None, + timestamp = timestamp or get_current_time_millis() + step = step or 0 + model_ids = ( + [model_id] + if model_id is not None + else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None]) ) + for model_id in model_ids: + return MlflowClient().log_metric( + run_id, + key, + value, + timestamp, + step, + synchronous=synchronous, + model_id=model_id, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, + ) def _log_inputs_for_metrics_if_necessary( @@ -930,6 +938,13 @@ def _log_inputs_for_metrics_if_necessary( ) +def _get_model_ids_for_new_metric_if_exist(run_id: str, metric_step: str) -> List[str]: + client = MlflowClient() + run = client.get_run(run_id) + model_outputs_at_step = [mo for mo in run.outputs.model_outputs if mo.step == 
metric_step] + return [mo.model_id for mo in model_outputs_at_step] + + def log_metrics( metrics: Dict[str, float], step: Optional[int] = None, @@ -979,8 +994,14 @@ def log_metrics( """ run_id = run_id or _get_or_start_run().info.run_id timestamp = timestamp or get_current_time_millis() + step = step or 0 dataset_name = dataset.name if dataset is not None else None dataset_digest = dataset.digest if dataset is not None else None + model_ids = ( + [model_id] + if model_id is not None + else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None]) + ) metrics_arr = [ Metric( key, @@ -992,6 +1013,7 @@ def log_metrics( dataset_digest=dataset_digest, ) for key, value in metrics.items() + for model_id in model_ids ] _log_inputs_for_metrics_if_necessary( run_id, metrics_arr, [dataset] if dataset is not None else None From 6f7cb0f28115a82f757e835e3de49561e6c3c402 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Thu, 22 Aug 2024 11:26:38 -0700 Subject: [PATCH 61/62] fix Signed-off-by: dbczumar --- mlflow/__init__.py | 12 +++--- mlflow/entities/__init__.py | 4 +- mlflow/entities/{model.py => logged_model.py} | 4 +- mlflow/models/model.py | 8 ++-- mlflow/store/model_registry/file_store.py | 4 +- mlflow/store/tracking/file_store.py | 40 +++++++++---------- mlflow/tracking/_tracking_service/client.py | 30 +++++++------- mlflow/tracking/client.py | 24 +++++------ mlflow/tracking/fluent.py | 16 ++++---- mlflow/utils/search_utils.py | 6 +-- 10 files changed, 74 insertions(+), 74 deletions(-) rename mlflow/entities/{model.py => logged_model.py} (97%) diff --git a/mlflow/__init__.py b/mlflow/__init__.py index 88e7f09efe60f..5382224912a24 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -126,7 +126,7 @@ active_run, autolog, create_experiment, - create_model, + create_logged_model, delete_experiment, delete_run, delete_tag, @@ -137,7 +137,7 @@ get_artifact_uri, get_experiment, get_experiment_by_name, - get_model, + get_logged_model, get_parent_run, get_run, 
last_active_run, @@ -155,7 +155,7 @@ log_table, log_text, search_experiments, - search_models, + search_logged_models, search_runs, set_experiment, set_experiment_tag, @@ -176,7 +176,7 @@ "active_run", "autolog", "create_experiment", - "create_model", + "create_logged_model", "delete_experiment", "delete_run", "delete_tag", @@ -192,7 +192,7 @@ "get_experiment", "get_experiment_by_name", "get_last_active_trace", - "get_model", + "get_logged_model", "get_parent_run", "get_registry_uri", "get_run", @@ -217,7 +217,7 @@ "register_model", "run", "search_experiments", - "search_models", + "search_logged_models", "search_model_versions", "search_registered_models", "search_runs", diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index 86283eb1c4707..4c02c95f43fd8 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -11,8 +11,8 @@ from mlflow.entities.file_info import FileInfo from mlflow.entities.input_tag import InputTag from mlflow.entities.lifecycle_stage import LifecycleStage +from mlflow.entities.logged_model import LoggedModel from mlflow.entities.metric import Metric -from mlflow.entities.model import Model from mlflow.entities.model_input import ModelInput from mlflow.entities.model_output import ModelOutput from mlflow.entities.model_param import ModelParam @@ -65,7 +65,7 @@ "TraceInfo", "SpanStatusCode", "_DatasetSummary", - "Model", + "LoggedModel", "ModelInput", "ModelOutput", "ModelStatus", diff --git a/mlflow/entities/model.py b/mlflow/entities/logged_model.py similarity index 97% rename from mlflow/entities/model.py rename to mlflow/entities/logged_model.py index 6f7c7c7f55736..b4fb1baa13757 100644 --- a/mlflow/entities/model.py +++ b/mlflow/entities/logged_model.py @@ -7,9 +7,9 @@ from mlflow.entities.model_tag import ModelTag -class Model(_MlflowObject): +class LoggedModel(_MlflowObject): """ - MLflow entity representing a Model. + MLflow entity representing a Model logged to an MLflow Experiment. 
""" def __init__( diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 355a9938ebaf1..2f9d807892ae5 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -733,7 +733,7 @@ def log_model_metrics_for_step(client, model_id, run_id, step): client = mlflow.MlflowClient(tracking_uri) active_run = mlflow.tracking.fluent.active_run() if model_id is not None: - model = client.get_model(model_id) + model = client.get_logged_model(model_id) else: params = { **(params or {}), @@ -743,7 +743,7 @@ def log_model_metrics_for_step(client, model_id, run_id, step): else {} ), } - model = client.create_model( + model = client.create_logged_model( experiment_id=mlflow.tracking.fluent._get_experiment_id(), # TODO: Update model name name=name, @@ -788,7 +788,7 @@ def log_model_metrics_for_step(client, model_id, run_id, step): _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) client.log_model_artifacts(model.model_id, local_path) - client.finalize_model(model.model_id, status=ModelStatus.READY) + client.finalize_logged_model(model.model_id, status=ModelStatus.READY) # # if the model_config kwarg is passed in, then log the model config as an params # if model_config := kwargs.get("model_config"): @@ -864,7 +864,7 @@ def log_model_metrics_for_step(client, model_id, run_id, step): ) return client.get_model_version(registered_model_name, registered_model.version) else: - return client.get_model(model.model_id) + return client.get_logged_model(model.model_id) # model_info = mlflow_model.get_model_info() # if registered_model is not None: # model_info.registered_model_version = registered_model.version diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index 3b57a9a2ff8a1..13852e451a59c 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -581,7 +581,7 @@ def _get_file_model_version_from_dir(self, directory) -> FileModelVersion: # URI (individual MlflowClient 
instances may have different tracking URIs) if "model_id" in meta: try: - model = MlflowClient().get_model(meta["model_id"]) + model = MlflowClient().get_logged_model(meta["model_id"]) meta["metrics"] = model.metrics meta["params"] = model.params except Exception: @@ -662,7 +662,7 @@ def next_version(registered_model_name): # TODO: Propagate tracking URI to file store directly, rather than relying on # global URI (individual MlflowClient instances may have different tracking # URIs) - model = MlflowClient().get_model(parsed_model_uri.model_id) + model = MlflowClient().get_logged_model(parsed_model_uri.model_id) storage_location = model.artifact_location run_id = run_id or model.run_id else: diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 59b2829eb0d26..007bab3ad87f7 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -15,8 +15,8 @@ Experiment, ExperimentTag, InputTag, + LoggedModel, Metric, - Model, ModelInput, ModelOutput, ModelParam, @@ -1901,7 +1901,7 @@ def _list_trace_infos(self, experiment_id): ) return trace_infos - def create_model( + def create_logged_model( self, experiment_id: str, name: str, @@ -1909,7 +1909,7 @@ def create_model( tags: Optional[List[ModelTag]] = None, params: Optional[List[ModelParam]] = None, model_type: Optional[str] = None, - ) -> Model: + ) -> LoggedModel: """ Create a new model. 
@@ -1942,7 +1942,7 @@ def create_model( model_id = str(uuid.uuid4()) artifact_location = self._get_model_artifact_dir(experiment_id, model_id) creation_timestamp = int(time.time() * 1000) - model = Model( + model = LoggedModel( experiment_id=experiment_id, model_id=model_id, name=name, @@ -1965,9 +1965,9 @@ def create_model( for tag in tags or []: self.set_model_tag(model_id=model_id, tag=tag) - return self.get_model(model_id=model_id) + return self.get_logged_model(model_id=model_id) - def finalize_model(self, model_id: str, status: ModelStatus) -> Model: + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: """ Finalize a model by updating its status. @@ -1984,17 +1984,17 @@ def finalize_model(self, model_id: str, status: ModelStatus) -> Model: databricks_pb2.INVALID_PARAMETER_VALUE, ) model_dict = self._get_model_dict(model_id) - model = Model.from_dictionary(model_dict) + model = LoggedModel.from_dictionary(model_dict) model.status = status model.last_updated_timestamp = int(time.time() * 1000) model_dir = self._get_model_dir(model.experiment_id, model.model_id) model_info_dict = self._make_persisted_model_dict(model) write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict, overwrite=True) - return self.get_model(model_id) + return self.get_logged_model(model_id) - def set_model_tag(self, model_id: str, tag: ModelTag): + def set_logged_model_tag(self, model_id: str, tag: ModelTag): _validate_tag_name(tag.key) - model = self.get_model(model_id) + model = self.get_logged_model(model_id) tag_path = os.path.join( self._get_model_dir(model.experiment_id, model.model_id), FileStore.TAGS_FOLDER_NAME, @@ -2004,8 +2004,8 @@ def set_model_tag(self, model_id: str, tag: ModelTag): # Don't add trailing newline write_to(tag_path, self._writeable_value(tag.value)) - def get_model(self, model_id: str) -> Model: - return Model.from_dictionary(self._get_model_dict(model_id)) + def get_logged_model(self, model_id: str) -> LoggedModel: 
+ return LoggedModel.from_dictionary(self._get_model_dict(model_id)) def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: return append_to_uri_path( @@ -2015,7 +2015,7 @@ def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: FileStore.ARTIFACTS_FOLDER_NAME, ) - def _make_persisted_model_dict(self, model: Model) -> Dict[str, Any]: + def _make_persisted_model_dict(self, model: LoggedModel) -> Dict[str, Any]: model_dict = model.to_dictionary() model_dict.pop("tags", None) model_dict.pop("metrics", None) @@ -2056,8 +2056,8 @@ def _find_model_root(self, model_id): return os.path.basename(os.path.dirname(os.path.abspath(models_dir_path))), models[0] return None, None - def _get_model_from_dir(self, model_dir: str) -> Model: - return Model.from_dictionary(self._get_model_info_from_dir(model_dir)) + def _get_model_from_dir(self, model_dir: str) -> LoggedModel: + return LoggedModel.from_dictionary(self._get_model_info_from_dir(model_dir)) def _get_model_info_from_dir(self, model_dir: str) -> Dict[str, Any]: model_dict = FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME) @@ -2138,21 +2138,21 @@ def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: st run_id=run_id, ) - def search_models( + def search_logged_models( self, experiment_ids: List[str], filter_string: Optional[str] = None, max_results: Optional[int] = None, order_by: Optional[List[str]] = None, - ) -> List[Model]: + ) -> List[LoggedModel]: all_models = [] for experiment_id in experiment_ids: models = self._list_models(experiment_id) all_models.extend(models) - filtered = SearchUtils.filter_models(models, filter_string) - return SearchUtils.sort_models(filtered, order_by)[:max_results] + filtered = SearchUtils.filter_logged_models(models, filter_string) + return SearchUtils.sort_logged_models(filtered, order_by)[:max_results] - def _list_models(self, experiment_id: str) -> List[Model]: + def _list_models(self, experiment_id: str) -> 
List[LoggedModel]: self._check_root_dir() if not self._has_experiment(experiment_id): return [] diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 9640269af49f4..f7a0c3f95e494 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -13,8 +13,8 @@ from mlflow.entities import ( ExperimentTag, + LoggedModel, Metric, - Model, ModelInput, ModelOutput, ModelParam, @@ -1024,7 +1024,7 @@ def search_runs( page_token=page_token, ) - def create_model( + def create_logged_model( self, experiment_id: str, name: str, @@ -1032,8 +1032,8 @@ def create_model( tags: Optional[Dict[str, str]] = None, params: Optional[Dict[str, str]] = None, model_type: Optional[str] = None, - ) -> Model: - return self.store.create_model( + ) -> LoggedModel: + return self.store.create_logged_model( experiment_id=experiment_id, name=name, run_id=run_id, @@ -1046,34 +1046,34 @@ def create_model( model_type=model_type, ) - def finalize_model(self, model_id: str, status: ModelStatus) -> Model: - return self.store.finalize_model(model_id, status) + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: + return self.store.finalize_logged_model(model_id, status) - def get_model(self, model_id: str) -> Model: - return self.store.get_model(model_id) + def get_logged_model(self, model_id: str) -> LoggedModel: + return self.store.get_logged_model(model_id) - def set_model_tag(self, model_id: str, key: str, value: str): - return self.store.set_model_tag(model_id, ModelTag(key, value)) + def set_logged_model_tag(self, model_id: str, key: str, value: str): + return self.store.set_logged_model_tag(model_id, ModelTag(key, value)) def log_model_artifacts(self, model_id: str, local_dir: str) -> None: - self._get_artifact_repo_for_model(model_id).log_artifacts(local_dir) + self._get_artifact_repo_for_logged_model(model_id).log_artifacts(local_dir) - def search_models( + def 
search_logged_models( self, experiment_ids: List[str], filter_string: Optional[str] = None, max_results: Optional[int] = None, order_by: Optional[List[str]] = None, ): - return self.store.search_models(experiment_ids, filter_string, max_results, order_by) + return self.store.search_logged_models(experiment_ids, filter_string, max_results, order_by) - def _get_artifact_repo_for_model(self, model_id: str) -> ArtifactRepository: + def _get_artifact_repo_for_logged_model(self, model_id: str) -> ArtifactRepository: # Attempt to fetch the artifact repo from a local cache cached_repo = utils._artifact_repos_cache.get(model_id) if cached_repo is not None: return cached_repo else: - model = self.get_model(model_id) + model = self.get_logged_model(model_id) artifact_uri = add_databricks_profile_info_to_artifact_uri( model.artifact_location, self.tracking_uri ) diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 3679c4f3b7c62..60394810cb44f 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -23,8 +23,8 @@ DatasetInput, Experiment, FileInfo, + LoggedModel, Metric, - Model, ModelInput, ModelOutput, ModelStatus, @@ -4760,7 +4760,7 @@ def print_model_version_info(mv): _validate_model_name(name) return self._get_registry_client().get_model_version_by_alias(name, alias) - def create_model( + def create_logged_model( self, experiment_id: str, name: str, @@ -4768,30 +4768,30 @@ def create_model( tags: Optional[Dict[str, str]] = None, params: Optional[Dict[str, str]] = None, model_type: Optional[str] = None, - ) -> Model: - return self._tracking_client.create_model( + ) -> LoggedModel: + return self._tracking_client.create_logged_model( experiment_id, name, run_id, tags, params, model_type ) - def finalize_model(self, model_id: str, status: ModelStatus) -> Model: - return self._tracking_client.finalize_model(model_id, status) + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: + return 
self._tracking_client.finalize_logged_model(model_id, status) - def get_model(self, model_id: str) -> Model: - return self._tracking_client.get_model(model_id) + def get_logged_model(self, model_id: str) -> LoggedModel: + return self._tracking_client.get_logged_model(model_id) - def set_model_tag(self, model_id: str, key: str, value: str): - return self._tracking_client.set_model_tag(model_id, key, value) + def set_logged_model_tag(self, model_id: str, key: str, value: str): + return self._tracking_client.set_logged_model_tag(model_id, key, value) def log_model_artifacts(self, model_id: str, local_dir: str) -> None: return self._tracking_client.log_model_artifacts(model_id, local_dir) - def search_models( + def search_logged_models( self, experiment_ids: List[str], filter_string: Optional[str] = None, max_results: Optional[int] = None, order_by: Optional[List[str]] = None, ): - return self._tracking_client.search_models( + return self._tracking_client.search_logged_models( experiment_ids, filter_string, max_results, order_by ) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 4561808d0d5c9..c231f8a3b6625 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -17,8 +17,8 @@ DatasetInput, Experiment, InputTag, + LoggedModel, Metric, - Model, ModelInput, Param, Run, @@ -1911,19 +1911,19 @@ def delete_experiment(experiment_id: str) -> None: MlflowClient().delete_experiment(experiment_id) -def create_model( +def create_logged_model( name: str, run_id: Optional[str] = None, tags: Optional[Dict[str, str]] = None, params: Optional[Dict[str, str]] = None, model_type: Optional[str] = None, experiment_id: Optional[str] = None, -) -> Model: +) -> LoggedModel: run = active_run() if run_id is None and run is not None: run_id = run.info.run_id experiment_id = experiment_id if experiment_id is not None else _get_experiment_id() - return MlflowClient().create_model( + return MlflowClient().create_logged_model( experiment_id=experiment_id, 
name=name, run_id=run_id, @@ -1933,17 +1933,17 @@ def create_model( ) -def get_model(model_id: str) -> Model: - return MlflowClient().get_model(model_id) +def get_logged_model(model_id: str) -> LoggedModel: + return MlflowClient().get_logged_model(model_id) -def search_models( +def search_logged_models( experiment_ids: Optional[List[str]] = None, filter_string: Optional[str] = None, max_results: Optional[int] = None, order_by: Optional[List[str]] = None, output_format: str = "pandas", -) -> Union[List[Model], "pandas.DataFrame"]: +) -> Union[List[LoggedModel], "pandas.DataFrame"]: experiment_ids = experiment_ids or [_get_experiment_id()] models = MlflowClient().search_models( experiment_ids=experiment_ids, diff --git a/mlflow/utils/search_utils.py b/mlflow/utils/search_utils.py index 6d3b6e0424562..0a983421cfc01 100644 --- a/mlflow/utils/search_utils.py +++ b/mlflow/utils/search_utils.py @@ -20,7 +20,7 @@ ) from sqlparse.tokens import Token as TokenType -from mlflow.entities import Model, RunInfo +from mlflow.entities import LoggedModel, RunInfo from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -680,7 +680,7 @@ def run_matches(run): return [run for run in runs if run_matches(run)] @classmethod - def filter_models(cls, models: List[Model], filter_string: Optional[str] = None): + def filter_logged_models(cls, models: List[LoggedModel], filter_string: Optional[str] = None): """Filters a set of runs based on a search filter string.""" if not filter_string: return models @@ -861,7 +861,7 @@ def sort(cls, runs, order_by_list): return runs @classmethod - def sort_models(cls, models, order_by_list): + def sort_logged_models(cls, models, order_by_list): models = sorted(models, key=lambda model: (-model.creation_timestamp, model.model_id)) if not order_by_list: return models From 887d513188e810f3a5edecf65d2759c4b86044ca Mon 
Sep 17 00:00:00 2001 From: dbczumar Date: Thu, 22 Aug 2024 11:29:16 -0700 Subject: [PATCH 62/62] fix Signed-off-by: dbczumar --- mlflow/store/artifact/models_artifact_repo.py | 2 +- mlflow/store/tracking/file_store.py | 2 +- mlflow/tracking/fluent.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlflow/store/artifact/models_artifact_repo.py b/mlflow/store/artifact/models_artifact_repo.py index c21878064b78a..614743200d147 100644 --- a/mlflow/store/artifact/models_artifact_repo.py +++ b/mlflow/store/artifact/models_artifact_repo.py @@ -95,7 +95,7 @@ def _get_model_uri_infos(uri): name = None version = None model_id = name_and_version_or_id[0] - download_uri = client.get_model(model_id).artifact_location + download_uri = client.get_logged_model(model_id).artifact_location else: name, version = name_and_version_or_id download_uri = client.get_model_version_download_uri(name, version) diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 007bab3ad87f7..1a4cf44ac1f8c 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -1963,7 +1963,7 @@ def create_logged_model( write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict) mkdir(model_dir, FileStore.METRICS_FOLDER_NAME) for tag in tags or []: - self.set_model_tag(model_id=model_id, tag=tag) + self.set_logged_model_tag(model_id=model_id, tag=tag) return self.get_logged_model(model_id=model_id) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index c231f8a3b6625..4e4d3f5fc10a7 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -1945,7 +1945,7 @@ def search_logged_models( output_format: str = "pandas", ) -> Union[List[LoggedModel], "pandas.DataFrame"]: experiment_ids = experiment_ids or [_get_experiment_id()] - models = MlflowClient().search_models( + models = MlflowClient().search_logged_models( experiment_ids=experiment_ids, filter_string=filter_string, 
max_results=max_results,