diff --git a/conftest.py b/conftest.py index 20ada52e6163e..717dafb83d782 100644 --- a/conftest.py +++ b/conftest.py @@ -7,7 +7,6 @@ import sys import threading -import click import pytest from mlflow.environment_variables import _MLFLOW_TESTING, MLFLOW_TRACKING_URI @@ -93,6 +92,8 @@ def pytest_cmdline_main(config): def pytest_sessionstart(session): if uri := MLFLOW_TRACKING_URI.get(): + import click + click.echo( click.style( ( diff --git a/dev/tracing-requirements.txt b/dev/tracing-requirements.txt new file mode 100644 index 0000000000000..598ba0dd0894a --- /dev/null +++ b/dev/tracing-requirements.txt @@ -0,0 +1,9 @@ +cachetools<6,>=5.0.0 +gitpython<4,>=3.1.9 +opentelemetry-api<3,>=1.9.0 +opentelemetry-sdk<3,>=1.9.0 +packaging<25 +protobuf<6,>=3.12.0 +requests<3,>=2.17.3 +# Databricks SDK is only required for [databricks] extra +databricks-sdk<1,>=0.20.0 diff --git a/examples/tracing/tracing_smaller_client.py b/examples/tracing/tracing_smaller_client.py new file mode 100644 index 0000000000000..0856a64c5ff44 --- /dev/null +++ b/examples/tracing/tracing_smaller_client.py @@ -0,0 +1,85 @@ +""" +This example demonstrates how to create a trace with multiple spans using the low-level MLflow client APIs. +""" + +import mlflow +from mlflow.tracing.destination import MlflowExperiment + +mlflow.login() + +mlflow.set_tracking_uri("databricks") +mlflow.tracing.set_destination( + MlflowExperiment( + experiment_id="ID of your experiment" + ) +) + +client = mlflow.MlflowClient() + +def run(x: int, y: int) -> int: + # Create a trace. The `start_trace` API returns a root span of the trace. + root_span = client.start_trace( + name="my_trace", + inputs={"x": x, "y": y}, + # Tags are key-value pairs associated with the trace. + # You can update the tags later using `client.set_trace_tag` API. + tags={ + "fruit": "apple", + "vegetable": "carrot", + }, + ) + + z = x + y + + # Request ID is a unique identifier for the trace. You will need this ID + # to interact with the trace later using the MLflow client. + request_id = root_span.request_id + + # Create a child span of the root span. + child_span = client.start_span( + name="child_span", + # Specify the request ID to which the child span belongs. + request_id=request_id, + # Also specify the ID of the parent span to build the span hierarchy. + # You can access the span ID via `span_id` property of the span object. + parent_id=root_span.span_id, + # Each span has its own inputs. + inputs={"z": z}, + # Attributes are key-value pairs associated with the span. + attributes={ + "model": "my_model", + "temperature": 0.5, + }, + ) + + z = z**2 + + # End the child span. Please make sure to end the child span before ending the root span. + client.end_span( + request_id=request_id, + span_id=child_span.span_id, + # Set the output(s) of the span. + outputs=z, + # Set the completion status, such as "OK" (default), "ERROR", etc. + status="OK", + ) + + z = z + 1 + + # End the root span. + client.end_trace( + request_id=request_id, + # Set the output(s) of the span. + outputs=z, + ) + + return z + + +assert run(1, 2) == 10 + +trace = mlflow.get_last_active_trace() +print("Last active trace", trace) + +assert trace.info.tags["fruit"] == "apple" +assert trace.info.tags["vegetable"] == "carrot" diff --git a/mlflow/__init__.py b/mlflow/__init__.py index e6f6c60966f04..e2ec9c93c7050 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -39,17 +39,17 @@ with contextlib.suppress(Exception): mlflow.mismatch._check_version_mismatch() -from mlflow import ( - artifacts, # noqa: F401 - client, # noqa: F401 - config, # noqa: F401 - data, # noqa: F401 - exceptions, # noqa: F401 - models, # noqa: F401 - projects, # noqa: F401 - tracing, # noqa: F401 - tracking, # noqa: F401 -) +# from mlflow import ( +# artifacts, # noqa: F401 +# client, # noqa: F401 +# config, # noqa: F401 +# data, # noqa: F401 +# exceptions, # noqa: F401 +# models, # noqa: F401 +# projects, # noqa: F401 +# tracing, # noqa: F401 +# tracking, # noqa: F401 +# ) from mlflow.environment_variables import MLFLOW_CONFIGURE_LOGGING from mlflow.utils.lazy_load import LazyLoader from mlflow.utils.logging_utils import _configure_mlflow_loggers @@ -105,162 +105,179 @@ if MLFLOW_CONFIGURE_LOGGING.get() is True: _configure_mlflow_loggers(root_module_name=__name__) +# NB: We need the client in order to support low-level tracing calls +# But we don't need most of the client methods...consider removing some from the tracing package? P1 from mlflow.client import MlflowClient # For backward compatibility, we expose the following functions and classes at the top level in # addition to `mlflow.config`. from mlflow.config import ( - disable_system_metrics_logging, - enable_system_metrics_logging, - get_registry_uri, - get_tracking_uri, - is_tracking_uri_set, - set_registry_uri, - set_system_metrics_node_id, - set_system_metrics_samples_before_logging, - set_system_metrics_sampling_interval, + # disable_system_metrics_logging, + # enable_system_metrics_logging, + # get_registry_uri, + # get_tracking_uri, + # is_tracking_uri_set, + # set_registry_uri, + # set_system_metrics_node_id, + # set_system_metrics_samples_before_logging, + # set_system_metrics_sampling_interval, set_tracking_uri, ) from mlflow.exceptions import MlflowException -from mlflow.models import evaluate -from mlflow.models.evaluation.validation import validate_evaluation_results -from mlflow.projects import run +# from mlflow.models import evaluate +# from mlflow.models.evaluation.validation import validate_evaluation_results +# from mlflow.projects import run from mlflow.tracing.fluent import ( add_trace, get_current_active_span, get_last_active_trace, - get_trace, + # get_trace, log_trace, - search_traces, + # search_traces, start_span, trace, update_current_trace, ) -from mlflow.tracking._model_registry.fluent import ( - register_model, - search_model_versions, - search_registered_models, -) +# from mlflow.tracking._model_registry.fluent import ( +# register_model, +# search_model_versions, +# search_registered_models, +# ) from mlflow.tracking.fluent import ( - ActiveRun, + # ActiveRun, + # TODO (TRACE REFACTOR) - Remove this but have autologgin import it elsewhere active_run, - autolog, - create_experiment, - create_logged_model, - delete_experiment, - delete_run, - delete_tag, - end_run, - flush_artifact_async_logging, - flush_async_logging, + + # autolog, + + # TODO: MIGHT NEED THIS! + # create_experiment, + + # create_logged_model, + # delete_experiment, + # delete_run, + # delete_tag, + # end_run, + + # TODO: MIGHT NEED THESE + # flush_artifact_async_logging, + # flush_async_logging, + flush_trace_async_logging, - get_artifact_uri, - get_experiment, - get_experiment_by_name, - get_logged_model, - get_parent_run, - get_run, - last_active_run, - load_table, - log_artifact, - log_artifacts, - log_dict, - log_figure, - log_image, - log_input, - log_metric, - log_metrics, - log_outputs, - log_param, - log_params, - log_table, - log_text, - search_experiments, - search_logged_models, - search_runs, - set_experiment, - set_experiment_tag, - set_experiment_tags, - set_tag, - set_tags, - start_run, + + # TODO: MIGHT NEED THIS! + # get_artifact_uri, + + # TODO: MIGHT NEED THESE! + # get_experiment, + # get_experiment_by_name, + + # get_logged_model, + # get_parent_run, + # get_run, + # last_active_run, + # load_table, + # log_artifact, + # log_artifacts, + # log_dict, + # log_figure, + # log_image, + # log_input, + # log_metric, + # log_metrics, + # log_outputs, + # log_param, + # log_params, + # log_table, + # log_text, + # search_experiments, + # search_logged_models, + # search_runs, + # set_experiment, + # set_experiment_tag, + # set_experiment_tags, + # set_tag, + # set_tags, + # start_run, ) -from mlflow.tracking.multimedia import Image -from mlflow.utils.async_logging.run_operations import RunOperations # noqa: F401 +# from mlflow.tracking.multimedia import Image +# from mlflow.utils.async_logging.run_operations import RunOperations # noqa: F401 + +# TODO: MIGHT NEED THIS! (PROBABLY DON'T) from mlflow.utils.credentials import login -from mlflow.utils.doctor import doctor +# from mlflow.utils.doctor import doctor __all__ = [ - "ActiveRun", - "MlflowClient", + # "ActiveRun", + # "MlflowClient", "MlflowException", - "active_run", - "autolog", - "create_experiment", - "create_logged_model", - "delete_experiment", - "delete_run", - "delete_tag", - "disable_system_metrics_logging", - "doctor", - "enable_system_metrics_logging", - "end_run", - "evaluate", - "flush_async_logging", - "flush_artifact_async_logging", + # "active_run", + # "autolog", + # "create_experiment", + # "create_logged_model", + # "delete_experiment", + # "delete_run", + # "delete_tag", + # "disable_system_metrics_logging", + # "doctor", + # "enable_system_metrics_logging", + # "end_run", + # "evaluate", + # "flush_async_logging", + # "flush_artifact_async_logging", "flush_trace_async_logging", - "get_artifact_uri", - "get_experiment", - "get_experiment_by_name", + # "get_artifact_uri", + # "get_experiment", + # "get_experiment_by_name", "get_last_active_trace", - "get_logged_model", - "get_parent_run", - "get_registry_uri", - "get_run", - "get_tracking_uri", - "is_tracking_uri_set", - "last_active_run", - "load_table", - "log_artifact", - "log_artifacts", - "log_dict", - "log_figure", - "log_image", - "log_input", - "log_outputs", - "log_metric", - "log_metrics", - "log_param", - "log_params", - "log_table", - "log_text", + # "get_logged_model", + # "get_parent_run", + # "get_registry_uri", + # "get_run", + # "get_tracking_uri", + # "is_tracking_uri_set", + # "last_active_run", + # "load_table", + # "log_artifact", + # "log_artifacts", + # "log_dict", + # "log_figure", + # "log_image", + # "log_input", + # "log_outputs", + # "log_metric", + # "log_metrics", + # "log_param", + # "log_params", + # "log_table", + # "log_text", "log_trace", "login", "pyfunc", - "register_model", - "run", - "search_experiments", - "search_logged_models", - "search_model_versions", - "search_registered_models", - "search_runs", - "set_experiment", - "set_experiment_tag", - "set_experiment_tags", - "set_registry_uri", - "set_system_metrics_node_id", - "set_system_metrics_samples_before_logging", - "set_system_metrics_sampling_interval", - "set_tag", - "set_tags", + # "register_model", + # "run", + # "search_experiments", + # "search_logged_models", + # "search_model_versions", + # "search_registered_models", + # "search_runs", + # "set_experiment", + # "set_experiment_tag", + # "set_experiment_tags", + # "set_registry_uri", + # "set_system_metrics_node_id", + # "set_system_metrics_samples_before_logging", + # "set_system_metrics_sampling_interval", + # "set_tag", + # "set_tags", "set_tracking_uri", - "start_run", - "validate_evaluation_results", - "Image", + # "start_run", + # "validate_evaluation_results", + # "Image", # Tracing Fluent APIs "get_current_active_span", "get_trace", - "search_traces", + # "search_traces", "start_span", "trace", "add_trace", diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 1e6fca2b84c5b..66fc23137c81f 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, Union from urllib.parse import urlparse -import yaml from packaging.requirements import InvalidRequirement, Requirement import mlflow @@ -713,6 +712,8 @@ def to_dict(self) -> dict[str, Any]: def to_yaml(self, stream=None) -> str: """Write the model as yaml string.""" + import yaml + return yaml.safe_dump(self.to_dict(), stream=stream, default_flow_style=False) def __str__(self): @@ -751,6 +752,8 @@ def load(cls, path) -> "Model": # Load the Model object from a remote model directory model2 = Model.load("s3://mybucket/path/to/my/model") """ + import yaml + # Check if the path is a local directory and not remote path_scheme = urlparse(str(path)).scheme if (not path_scheme or path_scheme == "file") and not os.path.exists(path): diff --git a/mlflow/models/model_config.py b/mlflow/models/model_config.py index 25336310cbbc0..eee3e63f48446 100644 --- a/mlflow/models/model_config.py +++ b/mlflow/models/model_config.py @@ -1,8 +1,6 @@ import os from typing import Any, Optional, Union -import yaml - from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -121,6 +119,8 @@ def _read_config(self): dict or None: The content of the YAML file as a dictionary, or None if the config path is not set. """ + import yaml + if isinstance(self.config, dict): return self.config diff --git a/mlflow/models/resources.py b/mlflow/models/resources.py index 6619005623d64..580b0175a2b69 100644 --- a/mlflow/models/resources.py +++ b/mlflow/models/resources.py @@ -3,8 +3,6 @@ from enum import Enum from typing import Any, Optional -import yaml - DEFAULT_API_VERSION = "1" diff --git a/mlflow/openai/__init__.py b/mlflow/openai/__init__.py index a86937c98ca9e..1ead8a3216405 100644 --- a/mlflow/openai/__init__.py +++ b/mlflow/openai/__init__.py @@ -41,17 +41,16 @@ from string import Formatter from typing import Any, Optional -import yaml from packaging.version import Version import mlflow from mlflow import pyfunc from mlflow.environment_variables import MLFLOW_OPENAI_SECRET_SCOPE from mlflow.exceptions import MlflowException -from mlflow.models import Model, ModelInputExample, ModelSignature +# TODO (TRACE REFACTOR) +from mlflow.models import Model +# from mlflow.models import Model, ModelInputExample, ModelSignature from mlflow.models.model import MLMODEL_FILE_NAME -from mlflow.models.signature import _infer_signature_from_input_example -from mlflow.models.utils import _save_example from mlflow.openai._openai_autolog import ( patched_agent_get_chat_completion, patched_call, @@ -60,7 +59,6 @@ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.types import ColSpec, Schema, TensorSpec from mlflow.utils.annotations import experimental from mlflow.utils.autologging_utils import autologging_integration, safe_patch from mlflow.utils.databricks_utils import ( @@ -211,6 +209,8 @@ def _get_openai_package_version(): def _log_secrets_yaml(local_model_dir, scope): + import yaml + with open(os.path.join(local_model_dir, "openai.yaml"), "w") as f: yaml.safe_dump({e.value: f"{scope}:{e.secret_key}" for e in _OpenAIEnvVar}, f) @@ -221,6 +221,8 @@ def _parse_format_fields(s) -> set[str]: def _get_input_schema(task, content): + from mlflow.types import ColSpec, Schema + if content: formatter = _ContentFormatter(task, content) variables = formatter.variables @@ -243,8 +245,11 @@ def save_model( conda_env=None, code_paths=None, mlflow_model=None, - signature: ModelSignature = None, - input_example: ModelInputExample = None, + # TODO (TRACE REFACTOR) + # signature: ModelSignature = None, + # input_example: ModelInputExample = None, + signature=None, + input_example=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, @@ -316,7 +321,13 @@ def save_model( if Version(_get_openai_package_version()).major < 1: raise MlflowException("Only openai>=1.0 is supported.") + import yaml + import numpy as np + + from mlflow.models.signature import _infer_signature_from_input_example + from mlflow.models.utils import _save_example + from mlflow.types import TensorSpec _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements) path = os.path.abspath(path) @@ -444,8 +455,11 @@ def log_model( conda_env=None, code_paths=None, registered_model_name=None, - signature: ModelSignature = None, - input_example: ModelInputExample = None, + # TODO (TRACE REFACTOR) + # signature: ModelSignature = None, + # input_example: ModelInputExample = None, + signature=None, + input_example=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, pip_requirements=None, extra_pip_requirements=None, @@ -563,6 +577,8 @@ def log_model( def _load_model(path): + import yaml + with open(path) as f: return yaml.safe_load(f) diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py index d3dc9c4b62328..14da50172bebe 100644 --- a/mlflow/store/model_registry/file_store.py +++ b/mlflow/store/model_registry/file_store.py @@ -54,7 +54,6 @@ write_to, write_yaml, ) -from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils, SearchUtils from mlflow.utils.string_utils import is_string_type from mlflow.utils.time import get_current_time_millis from mlflow.utils.validation import ( @@ -364,6 +363,8 @@ def search_registered_models( that satisfy the search expressions. The pagination token for the next page can be obtained via the ``token`` attribute of the object. """ + from mlflow.utils.search_utils import SearchModelUtils, SearchUtils + if not isinstance(max_results, int) or max_results < 1: raise MlflowException( "Invalid value for max_results. It must be a positive integer," @@ -892,6 +893,8 @@ def search_model_versions( page can be obtained via the ``token`` attribute of the object. """ + from mlflow.utils.search_utils import SearchModelVersionUtils, SearchUtils + if not isinstance(max_results, int) or max_results < 1: raise MlflowException( "Invalid value for max_results. It must be a positive integer," diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 63797bb69a8cf..87ef0e8473d72 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -87,11 +87,6 @@ _get_run_name_from_tags, ) from mlflow.utils.name_utils import _generate_random_name, _generate_unique_integer_id -from mlflow.utils.search_utils import ( - SearchExperimentsUtils, - SearchTraceUtils, - SearchUtils, -) from mlflow.utils.string_utils import is_string_type from mlflow.utils.time import get_current_time_millis from mlflow.utils.uri import ( @@ -318,6 +313,8 @@ def search_experiments( order_by=None, page_token=None, ): + from mlflow.utils.search_utils import SearchExperimentsUtils, SearchUtils + if not isinstance(max_results, int) or max_results < 1: raise MlflowException( f"Invalid value {max_results} for parameter 'max_results' supplied. It must be " @@ -995,6 +992,8 @@ def _search_runs( order_by, page_token, ): + from mlflow.utils.search_utils import SearchUtils + if max_results > SEARCH_MAX_RESULTS_THRESHOLD: raise MlflowException( "Invalid value for request parameter max_results. It must be at " @@ -1903,6 +1902,8 @@ def search_traces( some store implementations may not support pagination and thus the returned token would not be meaningful in such cases. """ + from mlflow.utils.search_utils import SearchTraceUtils + if max_results > SEARCH_MAX_RESULTS_THRESHOLD: raise MlflowException( "Invalid value for request parameter max_results. It must be at " diff --git a/mlflow/tracing/__init__.py b/mlflow/tracing/__init__.py index b66c71efda7ea..f109359351aac 100644 --- a/mlflow/tracing/__init__.py +++ b/mlflow/tracing/__init__.py @@ -1,8 +1,10 @@ from mlflow.tracing.display import disable_notebook_display, enable_notebook_display from mlflow.tracing.provider import disable, enable, reset, set_destination from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools +from mlflow.tracing.autologging import autolog __all__ = [ + "autolog", "disable", "enable", "disable_notebook_display", diff --git a/mlflow/tracing/autologging.py b/mlflow/tracing/autologging.py new file mode 100644 index 0000000000000..adc8d3b75ee23 --- /dev/null +++ b/mlflow/tracing/autologging.py @@ -0,0 +1,23 @@ +from typing import Optional + + +def autolog( + disable: bool = False, + silent: bool = False, + exclude_flavors: Optional[list[str]] = None, +) -> None: + from mlflow.tracking.fluent import autolog as _autolog + + return _autolog( + log_traces=True, + disable=disable, + exclude_flavors=exclude_flavors, + log_input_examples=False, + log_model_signatures=False, + log_models=False, + log_datasets=False, + exclusive=False, + disable_for_unsupported_versions=False, + silent=False, + extra_tags=None, + ) diff --git a/mlflow/tracing/export/mlflow.py b/mlflow/tracing/export/mlflow.py index e911fb8b24796..a94d10168de96 100644 --- a/mlflow/tracing/export/mlflow.py +++ b/mlflow/tracing/export/mlflow.py @@ -80,6 +80,7 @@ def export(self, root_spans: Sequence[ReadableSpan]): def _log_trace(self, trace: Trace): """Log the trace to MLflow backend.""" + self._client._upload_trace_data(trace.info, trace.data) upload_trace_data_task = Task( handler=self._client._upload_trace_data, args=(trace.info, trace.data), diff --git a/mlflow/tracing/fluent.py b/mlflow/tracing/fluent.py index 07a6228ffa19b..5e011f0d353a8 100644 --- a/mlflow/tracing/fluent.py +++ b/mlflow/tracing/fluent.py @@ -11,7 +11,6 @@ from cachetools import TTLCache from opentelemetry import trace as trace_api -from mlflow import MlflowClient from mlflow.entities import NoOpSpan, SpanType, Trace from mlflow.entities.span import LiveSpan, create_mlflow_span from mlflow.entities.span_event import SpanEvent @@ -46,6 +45,7 @@ start_client_span_or_trace, ) from mlflow.tracing.utils.search import extract_span_inputs_outputs, traces_to_df +from mlflow.tracking.client import MlflowClient from mlflow.tracking.fluent import _get_experiment_id from mlflow.utils import get_results_from_paginated_fn from mlflow.utils.annotations import experimental diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index a1ca944a00de8..f36da89b23ea6 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -17,8 +17,6 @@ import warnings from typing import TYPE_CHECKING, Any, Optional, Sequence, Union -import yaml - import mlflow from mlflow.entities import ( DatasetInput, @@ -2119,6 +2117,8 @@ def log_dict(self, run_id: str, dictionary: dict[str, Any], artifact_file: str) mlflow.log_dict(run_id, dictionary, "data.txt") """ + import yaml + extension = os.path.splitext(artifact_file)[1] with self._log_artifact_helper(run_id, artifact_file) as tmp_path: diff --git a/mlflow/types/__init__.py b/mlflow/types/__init__.py index 051b233fe6c21..c90aebe268b1b 100644 --- a/mlflow/types/__init__.py +++ b/mlflow/types/__init__.py @@ -3,14 +3,15 @@ components to describe interface independent of other frameworks or languages. """ -import mlflow.types.llm # noqa: F401 -from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec - -__all__ = [ - "Schema", - "ColSpec", - "DataType", - "TensorSpec", - "ParamSchema", - "ParamSpec", -] +# TODO (TRACE REFACTOR) +# import mlflow.types.llm # noqa: F401 +# from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec +# +# __all__ = [ +# "Schema", +# "ColSpec", +# "DataType", +# "TensorSpec", +# "ParamSchema", +# "ParamSpec", +# ] diff --git a/mlflow/utils/environment.py b/mlflow/utils/environment.py index 5caa5d8e82088..905d814fb3226 100644 --- a/mlflow/utils/environment.py +++ b/mlflow/utils/environment.py @@ -10,7 +10,6 @@ from copy import deepcopy from typing import Optional -import yaml from packaging.requirements import InvalidRequirement, Requirement from packaging.version import Version @@ -118,6 +117,8 @@ def from_dict(cls, dct): return cls(**dct) def to_yaml(self, path): + import yaml + with open(path, "w") as f: # Exclude None and empty lists data = {k: v for k, v in self.to_dict().items() if v} @@ -125,11 +126,15 @@ def to_yaml(self, path): @classmethod def from_yaml(cls, path): + import yaml + with open(path) as f: return cls.from_dict(yaml.safe_load(f)) @staticmethod def get_dependencies_from_conda_yaml(path): + import yaml + with open(path) as f: conda_env = yaml.safe_load(f) @@ -230,6 +235,8 @@ def _mlflow_conda_env( # noqa: D417 Conda environment. """ + import yaml + additional_pip_deps = additional_pip_deps or [] mlflow_deps = ( [f"mlflow=={VERSION}"] @@ -712,6 +719,8 @@ def _process_conda_env(conda_env): Processes `conda_env` passed to `mlflow.*.save_model` or `mlflow.*.log_model`, and returns a tuple of (conda_env, pip_requirements, pip_constraints). """ + import yaml + if isinstance(conda_env, str): with open(conda_env) as f: conda_env = yaml.safe_load(f) @@ -767,6 +776,8 @@ def _get_pip_install_mlflow(): def _get_requirements_from_file( file_path: pathlib.Path, ) -> list[Requirement]: + import yaml + data = file_path.read_text() if file_path.name == _CONDA_ENV_FILE_NAME: conda_env = yaml.safe_load(data) @@ -780,6 +791,8 @@ def _write_requirements_to_file( file_path: pathlib.Path, new_reqs: list[str], ) -> None: + import yaml + if file_path.name == _CONDA_ENV_FILE_NAME: conda_env = yaml.safe_load(file_path.read_text()) conda_env = _overwrite_pip_deps(conda_env, new_reqs) diff --git a/mlflow/utils/file_utils.py b/mlflow/utils/file_utils.py index f6cc20936a3e5..97ceb8e88c5d8 100644 --- a/mlflow/utils/file_utils.py +++ b/mlflow/utils/file_utils.py @@ -27,16 +27,8 @@ from urllib.parse import unquote from urllib.request import pathname2url -import yaml - from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE -try: - from yaml import CSafeDumper as YamlSafeDumper - from yaml import CSafeLoader as YamlSafeLoader -except ImportError: - from yaml import SafeDumper as YamlSafeDumper - from yaml import SafeLoader as YamlSafeLoader from mlflow.entities import FileInfo from mlflow.environment_variables import ( @@ -233,6 +225,15 @@ def write_yaml(root, file_name, data, overwrite=False, sort_keys=True, ensure_ya sort_keys: Whether to sort the keys when writing the yaml file. ensure_yaml_extension: If True, will automatically add .yaml extension if not given. """ + import yaml + + try: + from yaml import CSafeDumper as YamlSafeDumper + from yaml import CSafeLoader as YamlSafeLoader + except ImportError: + from yaml import SafeDumper as YamlSafeDumper + from yaml import SafeLoader as YamlSafeLoader + if not exists(root): raise MissingConfigException(f"Parent directory '{root}' does not exist.") @@ -300,6 +301,15 @@ def read_yaml(root, file_name): Returns: Data in yaml file as dictionary. """ + import yaml + + try: + from yaml import CSafeDumper as YamlSafeDumper + from yaml import CSafeLoader as YamlSafeLoader + except ImportError: + from yaml import SafeDumper as YamlSafeDumper + from yaml import SafeLoader as YamlSafeLoader + if not exists(root): raise MissingConfigException( f"Cannot read '{file_name}'. Parent dir '{root}' does not exist." @@ -312,17 +322,6 @@ def read_yaml(root, file_name): return yaml.load(yaml_file, Loader=YamlSafeLoader) -class UniqueKeyLoader(YamlSafeLoader): - def construct_mapping(self, node, deep=False): - mapping = set() - for key_node, _ in node.value: - key = self.construct_object(key_node, deep=deep) - if key in mapping: - raise ValueError(f"Duplicate '{key}' key found in YAML.") - mapping.add(key) - return super().construct_mapping(node, deep) - - def render_and_merge_yaml(root, template_name, context_name): """Renders a Jinja2-templated YAML file based on a YAML context file, merge them, and return result as a dictionary. @@ -335,9 +334,28 @@ def render_and_merge_yaml(root, template_name, context_name): Returns: Data in yaml file as dictionary. """ + import yaml + from jinja2 import FileSystemLoader, StrictUndefined from jinja2.sandbox import SandboxedEnvironment + try: + from yaml import CSafeDumper as YamlSafeDumper + from yaml import CSafeLoader as YamlSafeLoader + except ImportError: + from yaml import SafeDumper as YamlSafeDumper + from yaml import SafeLoader as YamlSafeLoader + + class UniqueKeyLoader(YamlSafeLoader): + def construct_mapping(self, node, deep=False): + mapping = set() + for key_node, _ in node.value: + key = self.construct_object(key_node, deep=deep) + if key in mapping: + raise ValueError(f"Duplicate '{key}' key found in YAML.") + mapping.add(key) + return super().construct_mapping(node, deep) + template_path = os.path.join(root, template_name) context_path = os.path.join(root, context_name) diff --git a/mlflow/utils/mime_type_utils.py b/mlflow/utils/mime_type_utils.py index 79465fed74d7c..578c0708d5a1e 100644 --- a/mlflow/utils/mime_type_utils.py +++ b/mlflow/utils/mime_type_utils.py @@ -6,10 +6,8 @@ # TODO: Create a module to define constants to avoid circular imports # and move MLMODEL_FILE_NAME and MLPROJECT_FILE_NAME in the module. def get_text_extensions(): - from mlflow.models.model import MLMODEL_FILE_NAME - from mlflow.projects._project_spec import MLPROJECT_FILE_NAME - return [ + text_extensions = [ "txt", "log", "err", @@ -33,9 +31,21 @@ def get_text_extensions(): "tsv", "md", "rst", - MLMODEL_FILE_NAME, - MLPROJECT_FILE_NAME, ] + try: + from mlflow.models.model import MLMODEL_FILE_NAME + + text_extensions.append(MLMODEL_FILE_NAME) + except ImportError: + pass + try: + from mlflow.projects._project_spec import MLPROJECT_FILE_NAME + + text_extensions.append(MLPROJECT_FILE_NAME) + except ImportError: + pass + + return text_extensions def _guess_mime_type(file_path): diff --git a/mlflow/utils/model_utils.py b/mlflow/utils/model_utils.py index 62f26719c0d7e..fde57437a6a02 100644 --- a/mlflow/utils/model_utils.py +++ b/mlflow/utils/model_utils.py @@ -7,8 +7,6 @@ from pathlib import Path from typing import Any -import yaml - from mlflow.exceptions import MlflowException from mlflow.models import Model from mlflow.models.model import MLMODEL_FILE_NAME @@ -385,6 +383,8 @@ def _get_overridden_pyfunc_model_config( def _validate_and_get_model_config_from_file(model_config): + import yaml + model_config = os.path.abspath(model_config) if os.path.exists(model_config): with open(model_config) as file: diff --git a/mlflow/utils/proto_json_utils.py b/mlflow/utils/proto_json_utils.py index c67929c511bb1..a8e97bc997197 100644 --- a/mlflow/utils/proto_json_utils.py +++ b/mlflow/utils/proto_json_utils.py @@ -9,12 +9,10 @@ from json import JSONEncoder from typing import Any, Optional -import pydantic from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import MessageToJson, ParseDict from mlflow.exceptions import MlflowException -from mlflow.utils import IS_PYDANTIC_V2_OR_NEWER _PROTOBUF_INT64_FIELDS = [ FieldDescriptor.TYPE_INT64, @@ -169,6 +167,9 @@ class NumpyEncoder(JSONEncoder): def try_convert(self, o): import numpy as np import pandas as pd + # MAY HAVE TO TRY CATCH!!! + import pydantic + from mlflow.utils import IS_PYDANTIC_V2_OR_NEWER def encode_binary(x): return base64.encodebytes(x).decode("ascii") diff --git a/mlflow/utils/pydantic_utils.py b/mlflow/utils/pydantic_utils.py index 232516b1b26cb..32cf1897a86d9 100644 --- a/mlflow/utils/pydantic_utils.py +++ b/mlflow/utils/pydantic_utils.py @@ -4,7 +4,9 @@ from packaging.version import Version from pydantic import BaseModel -IS_PYDANTIC_V2_OR_NEWER = Version(pydantic.VERSION).major >= 2 +# TODO (TRACE REFACTOR) +IS_PYDANTIC_V2_OR_NEWER = False +# IS_PYDANTIC_V2_OR_NEWER = Version(pydantic.VERSION).major >= 2 def model_dump_compat(pydantic_model: BaseModel, **kwargs: Any) -> dict[str, Any]: diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index d31e1fefdf19e..ba3ee12b5cf64 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -1,4 +1,4 @@ -## Dependencies required to run tests +#https://runbot-ci.cloud.databricks.com/build/TestShard-LKG-Aws/run/62381388# Dependencies required to run tests # Required for testing utilities for parsing pip requirements pip>=20.1 ## Test-only dependencies diff --git a/tests/conftest.py b/tests/conftest.py index 83c7b96af7c48..9dc3ee4262435 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,24 +100,25 @@ def enable_test_mode_by_default_for_autologging_integrations(): yield from enable_test_mode() -@pytest.fixture(autouse=True) -def clean_up_leaked_runs(): - """ - Certain test cases validate safety API behavior when runs are leaked. Leaked runs that - are not cleaned up between test cases may result in cascading failures that are hard to - debug. Accordingly, this fixture attempts to end any active runs it encounters and - throws an exception (which reported as an additional error in the pytest execution output). - """ - try: - yield - assert not mlflow.active_run(), ( - "test case unexpectedly leaked a run. Run info: {}. Run data: {}".format( - mlflow.active_run().info, mlflow.active_run().data - ) - ) - finally: - while mlflow.active_run(): - mlflow.end_run() +# TODO: HANDLE THE CASE WHERE ACTIVE RUN ISN'T AVAILABLE! +# @pytest.fixture(autouse=True) +# def clean_up_leaked_runs(): +# """ +# Certain test cases validate safety API behavior when runs are leaked. Leaked runs that +# are not cleaned up between test cases may result in cascading failures that are hard to +# debug. Accordingly, this fixture attempts to end any active runs it encounters and +# throws an exception (which reported as an additional error in the pytest execution output). +# """ +# try: +# yield +# assert not mlflow.active_run(), ( +# "test case unexpectedly leaked a run. Run info: {}. Run data: {}".format( +# mlflow.active_run().info, mlflow.active_run().data +# ) +# ) +# finally: +# while mlflow.active_run(): +# mlflow.end_run() def _called_in_save_model(): @@ -127,26 +128,27 @@ def _called_in_save_model(): return False -@pytest.fixture(autouse=True) -def prevent_infer_pip_requirements_fallback(request): - """ - Prevents `mlflow.models.infer_pip_requirements` from falling back in `mlflow.*.save_model` - unless explicitly disabled via `pytest.mark.allow_infer_pip_requirements_fallback`. - """ - from mlflow.utils.environment import _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE - - def new_exception(msg, *_, **__): - if msg == _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE and _called_in_save_model(): - raise Exception( - "`mlflow.models.infer_pip_requirements` should not fall back in" - "`mlflow.*.save_model` during test" - ) - - if "allow_infer_pip_requirements_fallback" not in request.keywords: - with mock.patch("mlflow.utils.environment._logger.exception", new=new_exception): - yield - else: - yield +# TODO: Figure out how to move this somewhere else! +# @pytest.fixture(autouse=True) +# def prevent_infer_pip_requirements_fallback(request): +# """ +# Prevents `mlflow.models.infer_pip_requirements` from falling back in `mlflow.*.save_model` +# unless explicitly disabled via `pytest.mark.allow_infer_pip_requirements_fallback`. +# """ +# from mlflow.utils.environment import _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE +# +# def new_exception(msg, *_, **__): +# if msg == _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE and _called_in_save_model(): +# raise Exception( +# "`mlflow.models.infer_pip_requirements` should not fall back in" +# "`mlflow.*.save_model` during test" +# ) +# +# if "allow_infer_pip_requirements_fallback" not in request.keywords: +# with mock.patch("mlflow.utils.environment._logger.exception", new=new_exception): +# yield +# else: +# yield @pytest.fixture(autouse=True) diff --git a/tests/helper_functions.py b/tests/helper_functions.py index 76ebef90dfffc..d477deb6731a5 100644 --- a/tests/helper_functions.py +++ b/tests/helper_functions.py @@ -17,18 +17,10 @@ import pytest import requests -import yaml import mlflow from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.utils.environment import ( - _CONDA_ENV_FILE_NAME, - _CONSTRAINTS_FILE_NAME, - _REQUIREMENTS_FILE_NAME, - _generate_mlflow_version_pinning, - _get_pip_deps, -) from mlflow.utils.file_utils import read_yaml, write_yaml from mlflow.utils.os import is_windows @@ -408,6 +400,8 @@ def create_mock_response(status_code, text): def _read_yaml(path): + import yaml + with open(path) as f: return yaml.safe_load(f) @@ -435,6 +429,8 @@ def _compare_logged_code_paths(code_path, model_path, flavor_name): def _compare_conda_env_requirements(env_path, req_path): + from mlflow.utils.environment import _get_pip_deps + assert os.path.exists(req_path) custom_env_parsed = _read_yaml(env_path) requirements = _read_lines(req_path) @@ -470,6 +466,13 @@ def _assert_pip_requirements(model_uri, requirements, constraints=None, strict=F If `strict` is True, evaluate `set(requirements) == set(loaded_requirements)`. Otherwise, evaluate `set(requirements) <= set(loaded_requirements)`. """ + from mlflow.utils.environment import ( + _CONDA_ENV_FILE_NAME, + _CONSTRAINTS_FILE_NAME, + _REQUIREMENTS_FILE_NAME, + _get_pip_deps, + ) + local_path = _download_artifact_from_uri(model_uri) txt_reqs = _read_lines(os.path.join(local_path, _REQUIREMENTS_FILE_NAME)) conda_reqs = _get_pip_deps(_read_yaml(os.path.join(local_path, _CONDA_ENV_FILE_NAME))) @@ -600,6 +603,8 @@ def assert_array_almost_equal(actual_array, desired_array, rtol=1e-6): def _mlflow_major_version_string(): + from mlflow.utils.environment import _generate_mlflow_version_pinning + return _generate_mlflow_version_pinning() diff --git a/tests/tracing/conftest.py b/tests/tracing/conftest.py index a21fea4b673b5..8a9fa78a34eae 100644 --- a/tests/tracing/conftest.py +++ b/tests/tracing/conftest.py @@ -1,3 +1,4 @@ +import os import subprocess import tempfile import time diff --git a/tests/tracing/display/test_ipython.py b/tests/tracing/display/test_ipython.py index bf618140fcec2..01890d03bec8d 100644 --- a/tests/tracing/display/test_ipython.py +++ b/tests/tracing/display/test_ipython.py @@ -1,9 +1,17 @@ +import pytest + +pytest.importorskip( + "IPython", + reason=( + "These tests require IPython. Run this suite separately from tracing core tests " + "in an environment with IPython installed." + ), +) + import json from collections import defaultdict from unittest.mock import Mock -import pytest - import mlflow from mlflow.tracing.display import ( IPythonTraceDisplayHandler, diff --git a/tests/tracing/test_fluent.py b/tests/tracing/test_fluent.py index b94a4a13e5653..4b853d98af762 100644 --- a/tests/tracing/test_fluent.py +++ b/tests/tracing/test_fluent.py @@ -20,9 +20,6 @@ from mlflow.entities.trace_status import TraceStatus from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME from mlflow.exceptions import MlflowException -from mlflow.pyfunc.context import Context, set_prediction_context -from mlflow.store.entities.paged_list import PagedList -from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS from mlflow.tracing.constant import ( TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY, @@ -37,8 +34,6 @@ from mlflow.utils.file_utils import local_file_uri_to_path from mlflow.utils.os import is_windows -from tests.tracing.helper import create_test_trace_info, get_traces - class DefaultTestModel: @mlflow.trace() @@ -356,6 +351,8 @@ def test_trace_in_databricks_model_serving( # Dummy flask app for prediction import flask + from mlflow.pyfunc.context import Context, set_prediction_context + app = flask.Flask(__name__) @app.route("/invocations", methods=["POST"]) @@ -461,6 +458,8 @@ def square(self, t): def test_trace_in_model_evaluation(mock_store, monkeypatch, async_logging_enabled): + from mlflow.pyfunc.context import Context, set_prediction_context + monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob") monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test") @@ -856,437 +855,6 @@ def test_get_trace(mock_get_display_handler): mock_logger.warning.assert_called_once() -def test_test_search_traces_empty(mock_client): - mock_client.search_traces.return_value = PagedList([], token=None) - - traces = mlflow.search_traces() - assert traces.empty - - default_columns = Trace.pandas_dataframe_columns() - assert traces.columns.tolist() == default_columns - - traces = mlflow.search_traces(extract_fields=["foo.inputs.bar"]) - assert traces.columns.tolist() == [*default_columns, "foo.inputs.bar"] - - mock_client.search_traces.assert_called() - - -def test_search_traces(mock_client): - mock_client.search_traces.return_value = PagedList( - [ - Trace( - info=create_test_trace_info(f"tr-{i}"), - data=TraceData([], "", ""), - ) - for i in range(10) - ], - token=None, - ) - - traces = mlflow.search_traces( - experiment_ids=["1"], - filter_string="name = 'foo'", - max_results=10, - order_by=["timestamp DESC"], - ) - - assert len(traces) == 10 - mock_client.search_traces.assert_called_once_with( - experiment_ids=["1"], - run_id=None, - filter_string="name = 'foo'", - max_results=10, - order_by=["timestamp DESC"], - page_token=None, - model_id=None, - ) - - -def test_search_traces_with_pagination(mock_client): - traces = [ - Trace( - info=create_test_trace_info(f"tr-{i}"), - data=TraceData([], "", ""), - ) - for i in range(30) - ] - - mock_client.search_traces.side_effect = [ - PagedList(traces[:10], token="token-1"), - PagedList(traces[10:20], token="token-2"), - PagedList(traces[20:], token=None), - ] - - traces = mlflow.search_traces(experiment_ids=["1"]) - - assert len(traces) == 30 - common_args = { - "experiment_ids": ["1"], - "run_id": None, - "max_results": SEARCH_TRACES_DEFAULT_MAX_RESULTS, - "filter_string": None, - "order_by": None, - } - mock_client.search_traces.assert_has_calls( - [ - mock.call(**common_args, page_token=None, model_id=None), - mock.call(**common_args, page_token="token-1", model_id=None), - mock.call(**common_args, page_token="token-2", model_id=None), - ] - ) - - -def test_search_traces_with_default_experiment_id(mock_client): - mock_client.search_traces.return_value = PagedList([], token=None) - with mock.patch("mlflow.tracing.fluent._get_experiment_id", return_value="123"): - mlflow.search_traces() - - mock_client.search_traces.assert_called_once_with( - experiment_ids=["123"], - run_id=None, - filter_string=None, - max_results=SEARCH_TRACES_DEFAULT_MAX_RESULTS, - order_by=None, - page_token=None, - model_id=None, - ) - - -def test_search_traces_yields_expected_dataframe_contents(monkeypatch): - model = DefaultTestModel() - client = mlflow.MlflowClient() - expected_traces = [] - for _ in range(10): - model.predict(2, 5) - time.sleep(0.1) - - # The in-memory trace returned from get_last_active_trace() is not guaranteed to be - # exactly same as the trace stored in the backend (e.g., tags created by the backend). - # Therefore, we fetch the trace from the backend to compare the results. - trace = client.get_trace(mlflow.get_last_active_trace().info.request_id) - expected_traces.append(trace) - - df = mlflow.search_traces(max_results=10, order_by=["timestamp ASC"]) - assert df.columns.tolist() == [ - "request_id", - "trace", - "timestamp_ms", - "status", - "execution_time_ms", - "request", - "response", - "request_metadata", - "spans", - "tags", - ] - for idx, trace in enumerate(expected_traces): - assert df.iloc[idx].request_id == trace.info.request_id - assert df.iloc[idx].trace.info.request_id == trace.info.request_id - assert df.iloc[idx].timestamp_ms == trace.info.timestamp_ms - assert df.iloc[idx].status == trace.info.status - assert df.iloc[idx].execution_time_ms == trace.info.execution_time_ms - assert df.iloc[idx].request == json.loads(trace.data.request) - assert df.iloc[idx].response == json.loads(trace.data.response) - assert df.iloc[idx].request_metadata == trace.info.request_metadata - assert df.iloc[idx].spans == [s.to_dict() for s in trace.data.spans] - assert df.iloc[idx].tags == trace.info.tags - - -def test_search_traces_handles_missing_response_tags_and_metadata(monkeypatch): - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return [ - Trace( - info=TraceInfo( - request_id=5, - experiment_id="test", - timestamp_ms=1, - execution_time_ms=2, - status=TraceStatus.OK, - ), - data=TraceData( - spans=[], - request="request", - # Response is missing - ), - ) - ] - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces() - assert df["response"].isnull().all() - assert df["tags"].tolist() == [{}] - assert df["request_metadata"].tolist() == [{}] - - -def test_search_traces_extracts_fields_as_expected(monkeypatch): - model = DefaultTestModel() - model.predict(2, 5) - - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces( - extract_fields=["predict.inputs.x", "predict.outputs", "add_one_with_custom_name.inputs.z"] - ) - assert df["predict.inputs.x"].tolist() == [2] - assert df["predict.outputs"].tolist() == [64] - assert df["add_one_with_custom_name.inputs.z"].tolist() == [7] - - -# Test cases should cover case where there are no spans at all -def test_search_traces_with_no_spans(monkeypatch): - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return [] - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces() - assert df.empty - - -# no spans have the input or output with name, -# some span has an input but we’re looking for output, -def test_search_traces_with_input_and_no_output(monkeypatch): - with mlflow.start_span(name="with_input_and_no_output") as span: - span.set_inputs({"a": 1}) - - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces( - extract_fields=["with_input_and_no_output.inputs.a", "with_input_and_no_output.outputs"] - ) - assert df["with_input_and_no_output.inputs.a"].tolist() == [1] - assert df["with_input_and_no_output.outputs"].isnull().all() - - -# Test case where span content is invalid -def test_search_traces_with_invalid_span_content(monkeypatch): - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - # Invalid span content - return [ - Trace( - info=TraceInfo( - request_id=5, - experiment_id="test", - timestamp_ms=1, - execution_time_ms=2, - status=TraceStatus.OK, - ), - data=TraceData(spans=[None], request="request", response="response"), - ) - ] - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - with pytest.raises(AttributeError, match="NoneType"): - mlflow.search_traces() - - -# Test case where span inputs / outputs aren’t dict -def test_search_traces_with_non_dict_span_inputs_outputs(monkeypatch): - with mlflow.start_span(name="non_dict_span") as span: - span.set_inputs(["a", "b"]) - span.set_outputs([1, 2, 3]) - - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces( - extract_fields=["non_dict_span.inputs", "non_dict_span.outputs", "non_dict_span.inputs.x"] - ) - assert df["non_dict_span.inputs"].tolist() == [["a", "b"]] - assert df["non_dict_span.outputs"].tolist() == [[1, 2, 3]] - assert df["non_dict_span.inputs.x"].isnull().all() - - -# Test case where there are multiple spans with the same name -def test_search_traces_with_multiple_spans_with_same_name(monkeypatch): - class TestModel: - @mlflow.trace(name="duplicate_name") - def predict(self, x, y): - z = x + y - z = self.add_one(z) - z = mlflow.trace(self.square)(z) - return z # noqa: RET504 - - @mlflow.trace(span_type=SpanType.LLM, name="duplicate_name", attributes={"delta": 1}) - def add_one(self, z): - return z + 1 - - def square(self, t): - res = t**2 - time.sleep(0.1) - return res - - model = TestModel() - model.predict(2, 5) - - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces( - extract_fields=[ - "duplicate_name.inputs.y", - "duplicate_name.inputs.x", - "duplicate_name.inputs.z", - "duplicate_name_1.inputs.x", - "duplicate_name_1.inputs.y", - "duplicate_name_2.inputs.z", - ] - ) - # Duplicate spans would all be null - assert df["duplicate_name.inputs.y"].isnull().all() - assert df["duplicate_name.inputs.x"].isnull().all() - assert df["duplicate_name.inputs.z"].isnull().all() - assert df["duplicate_name_1.inputs.x"].tolist() == [2] - assert df["duplicate_name_1.inputs.y"].tolist() == [5] - assert df["duplicate_name_2.inputs.z"].tolist() == [7] - - -# Test a field that doesn’t exist for extraction - we shouldn’t throw, just return empty column -def test_search_traces_with_non_existent_field(monkeypatch): - model = DefaultTestModel() - model.predict(2, 5) - - class MockMlflowClient: - def search_traces(self, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - df = mlflow.search_traces( - extract_fields=[ - "predict.inputs.k", - "predict.inputs.x", - "predict.outputs", - "add_one_with_custom_name.inputs.z", - ] - ) - assert df["predict.inputs.k"].isnull().all() - assert df["predict.inputs.x"].tolist() == [2] - assert df["predict.outputs"].tolist() == [64] - assert df["add_one_with_custom_name.inputs.z"].tolist() == [7] - - -# Test experiment ID doesn’t need to be specified -def test_search_traces_without_experiment_id(monkeypatch): - model = DefaultTestModel() - model.predict(2, 5) - - class MockMlflowClient: - def search_traces(self, experiment_ids, *args, **kwargs): - assert experiment_ids == ["0"] - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - mlflow.search_traces() - - -def test_search_traces_span_and_field_name_with_dot(): - with mlflow.start_span(name="span.name") as span: - span.set_inputs({"a.b": 0}) - span.set_outputs({"x.y": 1}) - - df = mlflow.search_traces( - extract_fields=[ - "`span.name`.inputs", - "`span.name`.inputs.`a.b`", - "`span.name`.outputs", - "`span.name`.outputs.`x.y`", - ] - ) - - assert df["span.name.inputs"].tolist() == [{"a.b": 0}] - assert df["span.name.inputs.a.b"].tolist() == [0] - assert df["span.name.outputs"].tolist() == [{"x.y": 1}] - assert df["span.name.outputs.x.y"].tolist() == [1] - - -def test_search_traces_with_span_name(monkeypatch): - class TestModel: - @mlflow.trace(name="span.llm") - def predict(self, x, y): - z = x + y - z = self.add_one(z) - z = mlflow.trace(self.square)(z) - return z # noqa: RET504 - - @mlflow.trace(span_type=SpanType.LLM, name="span.invalidname", attributes={"delta": 1}) - def add_one(self, z): - return z + 1 - - def square(self, t): - res = t**2 - time.sleep(0.1) - return res - - model = TestModel() - model.predict(2, 5) - - class MockMlflowClient: - def search_traces(self, experiment_ids, *args, **kwargs): - return get_traces() - - monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) - - -def test_search_traces_with_run_id(): - def _create_trace(name, tags=None): - with mlflow.start_span(name=name) as span: - for k, v in (tags or {}).items(): - mlflow.MlflowClient().set_trace_tag(request_id=span.request_id, key=k, value=v) - return span.request_id - - def _get_names(traces): - tags = traces["tags"].tolist() - return [tags[i].get(TraceTagKey.TRACE_NAME) for i in range(len(tags))] - - with mlflow.start_run() as run1: - _create_trace(name="tr-1") - _create_trace(name="tr-2", tags={"fruit": "apple"}) - - with mlflow.start_run() as run2: - _create_trace(name="tr-3") - _create_trace(name="tr-4", tags={"fruit": "banana"}) - _create_trace(name="tr-5", tags={"fruit": "apple"}) - - traces = mlflow.search_traces() - assert _get_names(traces) == ["tr-5", "tr-4", "tr-3", "tr-2", "tr-1"] - - traces = mlflow.search_traces(run_id=run1.info.run_id) - assert _get_names(traces) == ["tr-2", "tr-1"] - - traces = mlflow.search_traces( - run_id=run2.info.run_id, - filter_string="tag.fruit = 'apple'", - ) - assert _get_names(traces) == ["tr-5"] - - with pytest.raises(MlflowException, match="You cannot filter by run_id when it is already"): - mlflow.search_traces( - run_id=run2.info.run_id, - filter_string="metadata.mlflow.sourceRun = '123'", - ) - - @pytest.mark.parametrize( "extract_fields", [ @@ -1604,6 +1172,8 @@ def test_add_trace_raise_for_invalid_trace(): def test_add_trace_in_databricks_model_serving(mock_databricks_serving_with_tracing_env): + from mlflow.pyfunc.context import Context, set_prediction_context + # Mimic a remote service call that returns a trace as a part of the response def dummy_remote_call(): return {"prediction": 1, "trace": _SAMPLE_REMOTE_TRACE} diff --git a/tests/tracing/test_search.py b/tests/tracing/test_search.py new file mode 100644 index 0000000000000..49095da774d08 --- /dev/null +++ b/tests/tracing/test_search.py @@ -0,0 +1,474 @@ +import pytest + +import time +from unittest import mock + +import mlflow +from mlflow.entities import ( + SpanEvent, + SpanStatusCode, + SpanType, + Trace, + TraceData, + TraceInfo, +) +from mlflow.store.entities.paged_list import PagedList +from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS + +from tests.tracing.helper import create_test_trace_info, get_traces + + +class DefaultTestModel: + @mlflow.trace() + def predict(self, x, y): + z = x + y + z = self.add_one(z) + z = mlflow.trace(self.square)(z) + return z # noqa: RET504 + + @mlflow.trace(span_type=SpanType.LLM, name="add_one_with_custom_name", attributes={"delta": 1}) + def add_one(self, z): + return z + 1 + + def square(self, t): + res = t**2 + time.sleep(0.1) + return res + + +@pytest.fixture +def mock_client(): + client = mock.MagicMock() + with mock.patch("mlflow.tracing.fluent.MlflowClient", return_value=client): + yield client + + +def test_test_search_traces_empty(mock_client): + mock_client.search_traces.return_value = PagedList([], token=None) + + traces = mlflow.search_traces() + assert traces.empty + + default_columns = Trace.pandas_dataframe_columns() + assert traces.columns.tolist() == default_columns + + traces = mlflow.search_traces(extract_fields=["foo.inputs.bar"]) + assert traces.columns.tolist() == [*default_columns, "foo.inputs.bar"] + + mock_client.search_traces.assert_called() + + +def test_search_traces(mock_client): + mock_client.search_traces.return_value = PagedList( + [ + Trace( + info=create_test_trace_info(f"tr-{i}"), + data=TraceData([], "", ""), + ) + for i in range(10) + ], + token=None, + ) + + traces = mlflow.search_traces( + experiment_ids=["1"], + filter_string="name = 'foo'", + max_results=10, + order_by=["timestamp DESC"], + ) + + assert len(traces) == 10 + mock_client.search_traces.assert_called_once_with( + experiment_ids=["1"], + run_id=None, + filter_string="name = 'foo'", + max_results=10, + order_by=["timestamp DESC"], + page_token=None, + model_id=None, + ) + + +def test_search_traces_with_pagination(mock_client): + traces = [ + Trace( + info=create_test_trace_info(f"tr-{i}"), + data=TraceData([], "", ""), + ) + for i in range(30) + ] + + mock_client.search_traces.side_effect = [ + PagedList(traces[:10], token="token-1"), + PagedList(traces[10:20], token="token-2"), + PagedList(traces[20:], token=None), + ] + + traces = mlflow.search_traces(experiment_ids=["1"]) + + assert len(traces) == 30 + common_args = { + "experiment_ids": ["1"], + "run_id": None, + "max_results": SEARCH_TRACES_DEFAULT_MAX_RESULTS, + "filter_string": None, + "order_by": None, + } + mock_client.search_traces.assert_has_calls( + [ + mock.call(**common_args, page_token=None, model_id=None), + mock.call(**common_args, page_token="token-1", model_id=None), + mock.call(**common_args, page_token="token-2", model_id=None), + ] + ) + + +def test_search_traces_with_default_experiment_id(mock_client): + mock_client.search_traces.return_value = PagedList([], token=None) + with mock.patch("mlflow.tracing.fluent._get_experiment_id", return_value="123"): + mlflow.search_traces() + + mock_client.search_traces.assert_called_once_with( + experiment_ids=["123"], + run_id=None, + filter_string=None, + max_results=SEARCH_TRACES_DEFAULT_MAX_RESULTS, + order_by=None, + page_token=None, + model_id=None, + ) + + +def test_search_traces_yields_expected_dataframe_contents(monkeypatch): + model = DefaultTestModel() + client = mlflow.MlflowClient() + expected_traces = [] + for _ in range(10): + model.predict(2, 5) + time.sleep(0.1) + + # The in-memory trace returned from get_last_active_trace() is not guaranteed to be + # exactly same as the trace stored in the backend (e.g., tags created by the backend). + # Therefore, we fetch the trace from the backend to compare the results. + trace = client.get_trace(mlflow.get_last_active_trace().info.request_id) + expected_traces.append(trace) + + df = mlflow.search_traces(max_results=10, order_by=["timestamp ASC"]) + assert df.columns.tolist() == [ + "request_id", + "trace", + "timestamp_ms", + "status", + "execution_time_ms", + "request", + "response", + "request_metadata", + "spans", + "tags", + ] + for idx, trace in enumerate(expected_traces): + assert df.iloc[idx].request_id == trace.info.request_id + assert df.iloc[idx].trace.info.request_id == trace.info.request_id + assert df.iloc[idx].timestamp_ms == trace.info.timestamp_ms + assert df.iloc[idx].status == trace.info.status + assert df.iloc[idx].execution_time_ms == trace.info.execution_time_ms + assert df.iloc[idx].request == json.loads(trace.data.request) + assert df.iloc[idx].response == json.loads(trace.data.response) + assert df.iloc[idx].request_metadata == trace.info.request_metadata + assert df.iloc[idx].spans == [s.to_dict() for s in trace.data.spans] + assert df.iloc[idx].tags == trace.info.tags + + +def test_search_traces_handles_missing_response_tags_and_metadata(monkeypatch): + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return [ + Trace( + info=TraceInfo( + request_id=5, + experiment_id="test", + timestamp_ms=1, + execution_time_ms=2, + status=TraceStatus.OK, + ), + data=TraceData( + spans=[], + request="request", + # Response is missing + ), + ) + ] + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces() + assert df["response"].isnull().all() + assert df["tags"].tolist() == [{}] + assert df["request_metadata"].tolist() == [{}] + + +def test_search_traces_extracts_fields_as_expected(monkeypatch): + model = DefaultTestModel() + model.predict(2, 5) + + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces( + extract_fields=["predict.inputs.x", "predict.outputs", "add_one_with_custom_name.inputs.z"] + ) + assert df["predict.inputs.x"].tolist() == [2] + assert df["predict.outputs"].tolist() == [64] + assert df["add_one_with_custom_name.inputs.z"].tolist() == [7] + + +# Test cases should cover case where there are no spans at all +def test_search_traces_with_no_spans(monkeypatch): + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return [] + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces() + assert df.empty + + +# no spans have the input or output with name, +# some span has an input but we’re looking for output, +def test_search_traces_with_input_and_no_output(monkeypatch): + with mlflow.start_span(name="with_input_and_no_output") as span: + span.set_inputs({"a": 1}) + + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces( + extract_fields=["with_input_and_no_output.inputs.a", "with_input_and_no_output.outputs"] + ) + assert df["with_input_and_no_output.inputs.a"].tolist() == [1] + assert df["with_input_and_no_output.outputs"].isnull().all() + + +# Test case where span content is invalid +def test_search_traces_with_invalid_span_content(monkeypatch): + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + # Invalid span content + return [ + Trace( + info=TraceInfo( + request_id=5, + experiment_id="test", + timestamp_ms=1, + execution_time_ms=2, + status=TraceStatus.OK, + ), + data=TraceData(spans=[None], request="request", response="response"), + ) + ] + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + with pytest.raises(AttributeError, match="NoneType"): + mlflow.search_traces() + + +# Test case where span inputs / outputs aren’t dict +def test_search_traces_with_non_dict_span_inputs_outputs(monkeypatch): + with mlflow.start_span(name="non_dict_span") as span: + span.set_inputs(["a", "b"]) + span.set_outputs([1, 2, 3]) + + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces( + extract_fields=["non_dict_span.inputs", "non_dict_span.outputs", "non_dict_span.inputs.x"] + ) + assert df["non_dict_span.inputs"].tolist() == [["a", "b"]] + assert df["non_dict_span.outputs"].tolist() == [[1, 2, 3]] + assert df["non_dict_span.inputs.x"].isnull().all() + + +# Test case where there are multiple spans with the same name +def test_search_traces_with_multiple_spans_with_same_name(monkeypatch): + class TestModel: + @mlflow.trace(name="duplicate_name") + def predict(self, x, y): + z = x + y + z = self.add_one(z) + z = mlflow.trace(self.square)(z) + return z # noqa: RET504 + + @mlflow.trace(span_type=SpanType.LLM, name="duplicate_name", attributes={"delta": 1}) + def add_one(self, z): + return z + 1 + + def square(self, t): + res = t**2 + time.sleep(0.1) + return res + + model = TestModel() + model.predict(2, 5) + + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces( + extract_fields=[ + "duplicate_name.inputs.y", + "duplicate_name.inputs.x", + "duplicate_name.inputs.z", + "duplicate_name_1.inputs.x", + "duplicate_name_1.inputs.y", + "duplicate_name_2.inputs.z", + ] + ) + # Duplicate spans would all be null + assert df["duplicate_name.inputs.y"].isnull().all() + assert df["duplicate_name.inputs.x"].isnull().all() + assert df["duplicate_name.inputs.z"].isnull().all() + assert df["duplicate_name_1.inputs.x"].tolist() == [2] + assert df["duplicate_name_1.inputs.y"].tolist() == [5] + assert df["duplicate_name_2.inputs.z"].tolist() == [7] + + +# Test a field that doesn’t exist for extraction - we shouldn’t throw, just return empty column +def test_search_traces_with_non_existent_field(monkeypatch): + model = DefaultTestModel() + model.predict(2, 5) + + class MockMlflowClient: + def search_traces(self, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + df = mlflow.search_traces( + extract_fields=[ + "predict.inputs.k", + "predict.inputs.x", + "predict.outputs", + "add_one_with_custom_name.inputs.z", + ] + ) + assert df["predict.inputs.k"].isnull().all() + assert df["predict.inputs.x"].tolist() == [2] + assert df["predict.outputs"].tolist() == [64] + assert df["add_one_with_custom_name.inputs.z"].tolist() == [7] + + +# Test experiment ID doesn’t need to be specified +def test_search_traces_without_experiment_id(monkeypatch): + model = DefaultTestModel() + model.predict(2, 5) + + class MockMlflowClient: + def search_traces(self, experiment_ids, *args, **kwargs): + assert experiment_ids == ["0"] + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + mlflow.search_traces() + + +def test_search_traces_span_and_field_name_with_dot(): + with mlflow.start_span(name="span.name") as span: + span.set_inputs({"a.b": 0}) + span.set_outputs({"x.y": 1}) + + df = mlflow.search_traces( + extract_fields=[ + "`span.name`.inputs", + "`span.name`.inputs.`a.b`", + "`span.name`.outputs", + "`span.name`.outputs.`x.y`", + ] + ) + + assert df["span.name.inputs"].tolist() == [{"a.b": 0}] + assert df["span.name.inputs.a.b"].tolist() == [0] + assert df["span.name.outputs"].tolist() == [{"x.y": 1}] + assert df["span.name.outputs.x.y"].tolist() == [1] + + +def test_search_traces_with_span_name(monkeypatch): + class TestModel: + @mlflow.trace(name="span.llm") + def predict(self, x, y): + z = x + y + z = self.add_one(z) + z = mlflow.trace(self.square)(z) + return z # noqa: RET504 + + @mlflow.trace(span_type=SpanType.LLM, name="span.invalidname", attributes={"delta": 1}) + def add_one(self, z): + return z + 1 + + def square(self, t): + res = t**2 + time.sleep(0.1) + return res + + model = TestModel() + model.predict(2, 5) + + class MockMlflowClient: + def search_traces(self, experiment_ids, *args, **kwargs): + return get_traces() + + monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient) + + +def test_search_traces_with_run_id(): + def _create_trace(name, tags=None): + with mlflow.start_span(name=name) as span: + for k, v in (tags or {}).items(): + mlflow.MlflowClient().set_trace_tag(request_id=span.request_id, key=k, value=v) + return span.request_id + + def _get_names(traces): + tags = traces["tags"].tolist() + return [tags[i].get(TraceTagKey.TRACE_NAME) for i in range(len(tags))] + + with mlflow.start_run() as run1: + _create_trace(name="tr-1") + _create_trace(name="tr-2", tags={"fruit": "apple"}) + + with mlflow.start_run() as run2: + _create_trace(name="tr-3") + _create_trace(name="tr-4", tags={"fruit": "banana"}) + _create_trace(name="tr-5", tags={"fruit": "apple"}) + + traces = mlflow.search_traces() + assert _get_names(traces) == ["tr-5", "tr-4", "tr-3", "tr-2", "tr-1"] + + traces = mlflow.search_traces(run_id=run1.info.run_id) + assert _get_names(traces) == ["tr-2", "tr-1"] + + traces = mlflow.search_traces( + run_id=run2.info.run_id, + filter_string="tag.fruit = 'apple'", + ) + assert _get_names(traces) == ["tr-5"] + + with pytest.raises(MlflowException, match="You cannot filter by run_id when it is already"): + mlflow.search_traces( + run_id=run2.info.run_id, + filter_string="metadata.mlflow.sourceRun = '123'", + ) diff --git a/tests/tracing/export/test_databricks_agent_exporter.py b/tests/tracing/trace_logging/export/test_databricks_agent_exporter.py similarity index 100% rename from tests/tracing/export/test_databricks_agent_exporter.py rename to tests/tracing/trace_logging/export/test_databricks_agent_exporter.py diff --git a/tests/tracing/export/test_inference_table_exporter.py b/tests/tracing/trace_logging/export/test_inference_table_exporter.py similarity index 100% rename from tests/tracing/export/test_inference_table_exporter.py rename to tests/tracing/trace_logging/export/test_inference_table_exporter.py diff --git a/tests/tracing/export/test_mlflow_exporter.py b/tests/tracing/trace_logging/export/test_mlflow_exporter.py similarity index 100% rename from tests/tracing/export/test_mlflow_exporter.py rename to tests/tracing/trace_logging/export/test_mlflow_exporter.py diff --git a/tests/tracing/processor/test_inference_table_processor.py b/tests/tracing/trace_logging/processor/test_inference_table_processor.py similarity index 98% rename from tests/tracing/processor/test_inference_table_processor.py rename to tests/tracing/trace_logging/processor/test_inference_table_processor.py index ddc5c716c3b15..c1e43a1a512cd 100644 --- a/tests/tracing/processor/test_inference_table_processor.py +++ b/tests/tracing/trace_logging/processor/test_inference_table_processor.py @@ -5,7 +5,6 @@ from mlflow.entities.span import LiveSpan from mlflow.entities.trace_status import TraceStatus -from mlflow.pyfunc.context import Context, set_prediction_context from mlflow.tracing.constant import SpanAttributeKey from mlflow.tracing.processor.inference_table import ( _HEADER_REQUEST_ID_KEY, @@ -21,6 +20,8 @@ @pytest.mark.parametrize("context_type", ["mlflow", "flask"]) def test_on_start(context_type): + from mlflow.pyfunc.context import Context, set_prediction_context + # Root span should create a new trace on start span = create_mock_otel_span( trace_id=_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000 diff --git a/tests/tracing/processor/test_mlflow_processor.py b/tests/tracing/trace_logging/processor/test_mlflow_processor.py similarity index 99% rename from tests/tracing/processor/test_mlflow_processor.py rename to tests/tracing/trace_logging/processor/test_mlflow_processor.py index c44b5a66a0ae6..91d26c3b48e70 100644 --- a/tests/tracing/processor/test_mlflow_processor.py +++ b/tests/tracing/trace_logging/processor/test_mlflow_processor.py @@ -8,7 +8,6 @@ from mlflow.entities.span import LiveSpan from mlflow.entities.trace_status import TraceStatus from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME -from mlflow.pyfunc.context import Context, set_prediction_context from mlflow.tracing.constant import ( TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY, @@ -125,6 +124,8 @@ def test_on_start_with_experiment_id(monkeypatch): def test_on_start_during_model_evaluation(): + from mlflow.pyfunc.context import Context, set_prediction_context + # Root span should create a new trace on start span = create_mock_otel_span(trace_id=_TRACE_ID, span_id=1) mock_client = mock.MagicMock() diff --git a/tests/tracing/test_provider.py b/tests/tracing/trace_logging/test_provider.py similarity index 100% rename from tests/tracing/test_provider.py rename to tests/tracing/trace_logging/test_provider.py diff --git a/tests/tracing/test_trace_manager.py b/tests/tracing/trace_logging/test_trace_manager.py similarity index 100% rename from tests/tracing/test_trace_manager.py rename to tests/tracing/trace_logging/test_trace_manager.py diff --git a/tests/tracing/utils/test_otlp.py b/tests/tracing/utils/test_otlp.py index b2a82769542dc..5a19480114bc9 100644 --- a/tests/tracing/utils/test_otlp.py +++ b/tests/tracing/utils/test_otlp.py @@ -1,4 +1,14 @@ import pytest + +pytest.importorskip( + "opentelemetry.exporter", + reason=( + "These tests require opentelemetry-exporter-otlp-proto-http and " + "opentelemetry-exporter-otlp-proto-grpc. Run this suite separately from " + "other tracing core tests with these dependencies installed." + ), +) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpExporter diff --git a/tests/tracing/utils/test_search.py b/tests/tracing/utils/test_search_parsing.py similarity index 100% rename from tests/tracing/utils/test_search.py rename to tests/tracing/utils/test_search_parsing.py diff --git a/tests/tracing/utils/test_timeout.py b/tests/tracing/utils/test_timeout.py index f94585655f40e..75286b063ce99 100644 --- a/tests/tracing/utils/test_timeout.py +++ b/tests/tracing/utils/test_timeout.py @@ -7,7 +7,6 @@ import mlflow from mlflow.entities.span_event import SpanEvent from mlflow.entities.span_status import SpanStatusCode -from mlflow.pyfunc.context import Context, set_prediction_context from mlflow.tracing.export.inference_table import _TRACE_BUFFER, pop_trace from mlflow.tracing.trace_manager import _Trace from mlflow.tracing.utils.timeout import MlflowTraceTimeoutCache @@ -100,6 +99,8 @@ def test_trace_halted_after_timeout(monkeypatch): def test_trace_halted_after_timeout_in_model_serving( monkeypatch, mock_databricks_serving_with_tracing_env ): + from mlflow.pyfunc.context import Context, set_prediction_context + monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "3") # Simulate model serving env where multiple requests are processed concurrently diff --git a/tests/tracing/utils/test_utils.py b/tests/tracing/utils/test_utils.py index b46662c3374f0..38e563f4d9cfd 100644 --- a/tests/tracing/utils/test_utils.py +++ b/tests/tracing/utils/test_utils.py @@ -1,5 +1,4 @@ import pytest -from pydantic import ValidationError import mlflow from mlflow.entities import LiveSpan @@ -135,6 +134,11 @@ def test_set_span_chat_messages_append(): def test_set_chat_messages_validation(): + ValidationError = ( + pytest.importorskip("pydantic", reason="pydantic is required for chat message")\ + .ValidationError + ) + messages = [{"invalid_field": "user", "content": "hello"}] @mlflow.trace(span_type=SpanType.CHAT_MODEL) @@ -148,6 +152,11 @@ def dummy_call(messages): def test_set_chat_tools_validation(): + ValidationError = ( + pytest.importorskip("pydantic", reason="pydantic is required for chat message")\ + .ValidationError + ) + tools = [ { "type": "unsupported_function",