From bcd1ba35a9b38faba84c8ed0864a0bb274c6c99f Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Wed, 9 Dec 2020 16:24:29 -0800 Subject: [PATCH 01/22] Add public API for autologging integration configs Signed-off-by: Mohamad Arabi --- mlflow/utils/autologging_utils.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/mlflow/utils/autologging_utils.py b/mlflow/utils/autologging_utils.py index 3faf27446d734..e1f3beb112190 100644 --- a/mlflow/utils/autologging_utils.py +++ b/mlflow/utils/autologging_utils.py @@ -316,6 +316,30 @@ def autolog(**kwargs): return wrapper +def autologging_integration_config(flavor_name, config_key, default_value=None): + """ + Returns a desired config value for a specified autologging integration. + Returns `None` if specified `flavor_name` has no recorded configs. + If `config_key` is not set on the config object, default vlaue is returned. + + :param flavor_name: An autologging integration flavor name. + :param config_key: The key for the desired config value. + :param default_value: The default_value to return + """ + config = AUTOLOGGING_INTEGRATIONS.get(flavor_name) + if config is not None: + return config.get(config_key, default_value) + + +def autologging_is_disabled(flavor_name): + """ + Returns a boolean flag of whether the autologging integration is disabled. + + :param flavor_name: An autologging integration flavor name. + """ + return autologging_integration_config(flavor_name, "disable", False) + + def _is_testing(): """ Indicates whether or not autologging functionality is running in test mode (as determined @@ -550,10 +574,9 @@ def safe_patch_function(*args, **kwargs): """ original = gorilla.get_original_attribute(destination, function_name) - config = AUTOLOGGING_INTEGRATIONS.get(autologging_integration) # If the autologging integration associated with this patch is disabled, # call the original function and return - if config is not None and config.get("disable", False): + if autologging_is_disabled(autologging_integration): return original(*args, **kwargs) # Whether or not the original / underlying function has been called during the From 1c72efe020b647af3184d40d85b51b5fef3709c4 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Fri, 11 Dec 2020 10:20:28 +0900 Subject: [PATCH 02/22] Prefix with Lightning (#3809) Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- tests/pytorch/test_pytorch_autolog.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_pytorch_autolog.py b/tests/pytorch/test_pytorch_autolog.py index 46c8975e7602d..98cc84d971c37 100644 --- a/tests/pytorch/test_pytorch_autolog.py +++ b/tests/pytorch/test_pytorch_autolog.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import pytest import pytorch_lightning as pl import torch @@ -59,7 +61,10 @@ def test_pytorch_autolog_logs_expected_data(pytorch_model): # Testing optimizer parameters are logged assert "optimizer_name" in data.params - assert data.params["optimizer_name"] == "Adam" + + # In pytorch-lightning >= 1.1.0, optimizer names are prefixed with "Lightning". + prefix = "Lightning" if LooseVersion(pl.__version__) >= LooseVersion("1.1.0") else "" + assert data.params["optimizer_name"] == prefix + "Adam" # Testing model_summary.txt is saved client = mlflow.tracking.MlflowClient() From 5a264466a9e84ec72ec17d7d290dddb707114e75 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Fri, 11 Dec 2020 10:21:25 +0900 Subject: [PATCH 03/22] [HOT FIX] skip `test` job if matrix ends up being empty in cross version tests (#3800) * skip if the matrix is empty Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * set is_matrix_empty Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * fix syntax error Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * minor comment fix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- .github/workflows/cross-version-tests.yml | 2 ++ dev/set_matrix.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/cross-version-tests.yml b/.github/workflows/cross-version-tests.yml index c1dca09262dca..6549861f82b4e 100644 --- a/.github/workflows/cross-version-tests.yml +++ b/.github/workflows/cross-version-tests.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} + is_matrix_empty: ${{ steps.set-matrix.outputs.is_matrix_empty }} steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 @@ -40,6 +41,7 @@ jobs: fi test: needs: set-matrix + if: ${{ needs.set-matrix.outputs.is_matrix_empty == 'false' }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/dev/set_matrix.py b/dev/set_matrix.py index 974a7b2825526..0954448315f28 100644 --- a/dev/set_matrix.py +++ b/dev/set_matrix.py @@ -488,6 +488,10 @@ def main(): # Note that this actually doesn't print anything to the console. print("::set-output name=matrix::{}".format(json.dumps(matrix))) + # Set a flag that indicates whether or not the matrix is empty. If this flag is 'false', + # skip the subsequent jobs. + print("::set-output name=is_matrix_empty::{}".format("false" if job_names else "true")) + if __name__ == "__main__": main() From f43e5b69655821707ed2431a53cc2a0be8217185 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Fri, 11 Dec 2020 12:13:43 +0900 Subject: [PATCH 04/22] Allow using post releases in the cross version tests (#3807) * Fix for xgboost 1.3.0 Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * do not include 1.3.0 since it has been removed Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * re-run all the tests if set_matrix contains changes Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * nit Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * fix regexp Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * add test case Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Refactor using packaging Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * add packaging Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * nit Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- .github/workflows/cross-version-tests.yml | 2 +- dev/set_matrix.py | 65 ++++++++--------------- 2 files changed, 22 insertions(+), 45 deletions(-) diff --git a/.github/workflows/cross-version-tests.yml b/.github/workflows/cross-version-tests.yml index 6549861f82b4e..aeee9d7fb893e 100644 --- a/.github/workflows/cross-version-tests.yml +++ b/.github/workflows/cross-version-tests.yml @@ -21,7 +21,7 @@ jobs: python-version: "3.6" - name: Install dependencies run: | - pip install pyyaml pytest + pip install packaging pyyaml pytest - name: Test set_matrix.py run: | pytest dev/set_matrix.py --doctest-modules --verbose diff --git a/dev/set_matrix.py b/dev/set_matrix.py index 0954448315f28..1d437bf224063 100644 --- a/dev/set_matrix.py +++ b/dev/set_matrix.py @@ -34,7 +34,7 @@ """ import argparse -from distutils.version import LooseVersion +from packaging.version import Version import json import operator import os @@ -104,35 +104,6 @@ def get_released_versions(package_name): return versions -def get_major_version(ver): - """ - Examples - -------- - >>> get_major_version("1.2.3") - 1 - """ - return LooseVersion(ver).version[0] - - -def is_final_release(ver): - """ - Returns True if the given version matches PEP440's final release scheme. - - Examples - -------- - >>> is_final_release("0.1") - True - >>> is_final_release("0.23.0") - True - >>> is_final_release("0.4.0a1") - False - >>> is_final_release("0.5.0rc") - False - """ - # Ref.: https://www.python.org/dev/peps/pep-0440/#final-releases - return re.search(r"^\d+(\.\d+)+$", ver) is not None - - def select_latest_micro_versions(versions): """ Selects the latest micro version in each minor version. @@ -155,10 +126,10 @@ def select_latest_micro_versions(versions): for ver, _ in sorted( versions.items(), # Sort by (minor_version, upload_time) in descending order - key=lambda x: (LooseVersion(x[0]).version[:2], x[1]), + key=lambda x: (Version(x[0]).release[:2], x[1]), reverse=True, ): - minor_ver = tuple(LooseVersion(ver).version[:2]) # A set doesn't accept a list + minor_ver = Version(ver).release[:2] if minor_ver not in seen_minors: seen_minors.add(minor_ver) @@ -171,9 +142,10 @@ def filter_versions(versions, min_ver, max_ver, excludes=None): """ Filter versions that satisfy the following conditions: - 1. is newer than or equal to `min_ver` - 2. shares the same major version as `max_ver` or `min_ver` - 3. (Optional) is not in `excludes` + 1. is a final or post release that PEP 440 defines + 2. is newer than or equal to `min_ver` + 3. shares the same major version as `max_ver` or `min_ver` + 4. (Optional) is not in `excludes` Examples -------- @@ -198,12 +170,16 @@ def filter_versions(versions, min_ver, max_ver, excludes=None): assert max_ver in versions assert all(v in versions for v in excludes) - versions = {v: t for v, t in versions.items() if v not in excludes} - versions = {v: t for v, t in versions.items() if is_final_release(v)} + versions = {Version(v): t for v, t in versions.items() if v not in excludes} + + def _is_final_or_post_release(v): + # final release: https://www.python.org/dev/peps/pep-0440/#final-releases + # post release: https://www.python.org/dev/peps/pep-0440/#post-releases + return (v.base_version == v.public) or (v.is_postrelease) - max_major = get_major_version(max_ver) - versions = {v: t for v, t in versions.items() if get_major_version(v) <= max_major} - versions = {v: t for v, t in versions.items() if LooseVersion(v) >= LooseVersion(min_ver)} + versions = {v: t for v, t in versions.items() if _is_final_or_post_release(v)} + versions = {v: t for v, t in versions.items() if v.major <= Version(max_ver).major} + versions = {str(v): t for v, t in versions.items() if v >= Version(min_ver)} return versions @@ -324,8 +300,7 @@ def process_requirements(requirements, version=None): op_and_ver_pairs = map(get_operator_and_version, ver_spec.split(",")) match_all = all( comp_op( - LooseVersion(version), - LooseVersion(dev_numeric if req_ver == DEV_VERSION else req_ver), + Version(version), Version(dev_numeric if req_ver == DEV_VERSION else req_ver), ) for comp_op, req_ver in op_and_ver_pairs ) @@ -475,7 +450,9 @@ def main(): ) diff_flavor = set(filter(lambda x: x["flavor"] in changed_flavors, matrix)) - include = sorted(diff_config.union(diff_flavor), key=lambda x: x["job_name"]) + # If this file contains changes, re-run all the tests, otherwise re-run the affected tests. + include = matrix if (__file__ in changed_files) else diff_config.union(diff_flavor) + include = sorted(include, key=lambda x: x["job_name"]) job_names = [x["job_name"] for x in include] matrix = {"job_name": job_names, "include": include} @@ -488,7 +465,7 @@ def main(): # Note that this actually doesn't print anything to the console. print("::set-output name=matrix::{}".format(json.dumps(matrix))) - # Set a flag that indicates whether or not the matrix is empty. If this flag is 'false', + # Set a flag that indicates whether or not the matrix is empty. If this flag is 'true', # skip the subsequent jobs. print("::set-output name=is_matrix_empty::{}".format("false" if job_names else "true")) From 3ce637f0131f6d946c64f82cf6198f052d4a16a8 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Fri, 11 Dec 2020 11:12:33 -0800 Subject: [PATCH 05/22] Add support for disabling spark autologging Signed-off-by: Mohamad Arabi --- mlflow/_spark_autologging.py | 11 +- mlflow/spark.py | 6 +- .../test_spark_disable_autologging.py | 140 ++++++++++++++++++ tests/spark_autologging/utils.py | 4 + 4 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 tests/spark_autologging/test_spark_disable_autologging.py diff --git a/mlflow/_spark_autologging.py b/mlflow/_spark_autologging.py index 1ddba8643695a..477a61e3e782a 100644 --- a/mlflow/_spark_autologging.py +++ b/mlflow/_spark_autologging.py @@ -15,10 +15,14 @@ from mlflow.tracking.client import MlflowClient from mlflow.tracking.context.abstract_context import RunContextProvider from mlflow.utils import gorilla -from mlflow.utils.autologging_utils import wrap_patch +from mlflow.utils.autologging_utils import ( + wrap_patch, + autologging_is_disabled, +) _JAVA_PACKAGE = "org.mlflow.spark.autologging" _SPARK_TABLE_INFO_TAG_NAME = "sparkDatasourceInfo" +FLAVOR_NAME = "spark" _logger = logging.getLogger(__name__) _lock = threading.Lock() @@ -217,6 +221,8 @@ def _notify(self, path, version, data_format): Method called by Scala SparkListener to propagate datasource read events to the current Python process """ + if autologging_is_disabled(FLAVOR_NAME): + return # If there's an active run, simply set the tag on it # Note that there's a TOCTOU race condition here - active_run() here can actually throw # if the main thread happens to end the run & pop from the active run stack after we check @@ -248,6 +254,9 @@ def in_context(self): return True def tags(self): + # if autologging is disabled, then short circuit `tags()` and return empty dict. + if autologging_is_disabled(FLAVOR_NAME): + return {} with _lock: global _table_infos seen = set() diff --git a/mlflow/spark.py b/mlflow/spark.py index 6d1341d98f5d8..bfc24d7b9f656 100644 --- a/mlflow/spark.py +++ b/mlflow/spark.py @@ -41,6 +41,7 @@ from mlflow.utils.model_utils import _get_flavor_configuration_from_uri from mlflow.utils.annotations import experimental from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS +from mlflow.utils.autologging_utils import autologging_integration FLAVOR_NAME = "spark" @@ -627,7 +628,8 @@ def predict(self, pandas_df): @experimental -def autolog(): +@autologging_integration(FLAVOR_NAME) +def autolog(disable=False): # pylint: disable=unused-argument """ Enables automatic logging of Spark datasource paths, versions (if applicable), and formats when they are read. This method is not threadsafe and assumes a @@ -682,6 +684,8 @@ def autolog(): # next-created MLflow run if no run is currently active with mlflow.start_run() as active_run: pandas_df = loaded_df.toPandas() + + :param disable: Whether to enable or disable autologging. """ from mlflow import _spark_autologging diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py new file mode 100644 index 0000000000000..b5283dd63bc56 --- /dev/null +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -0,0 +1,140 @@ +import time + +import pytest + +import mlflow +import mlflow.spark + +from tests.tracking.test_rest_tracking import BACKEND_URIS +from tests.tracking.test_rest_tracking import tracking_server_uri # pylint: disable=unused-import +from tests.tracking.test_rest_tracking import mlflow_client # pylint: disable=unused-import +from tests.spark_autologging.utils import ( + _assert_spark_data_logged, + _assert_spark_data_not_logged, +) +from tests.spark_autologging.utils import spark_session # pylint: disable=unused-import +from tests.spark_autologging.utils import format_to_file_path # pylint: disable=unused-import +from tests.spark_autologging.utils import data_format # pylint: disable=unused-import +from tests.spark_autologging.utils import file_path # pylint: disable=unused-import + + +def pytest_generate_tests(metafunc): + """ + Automatically parametrize each each fixture/test that depends on `backend_store_uri` with the + list of backend store URIs. + """ + if "backend_store_uri" in metafunc.fixturenames: + metafunc.parametrize("backend_store_uri", BACKEND_URIS) + + +@pytest.fixture() +def http_tracking_uri_mock(): + mlflow.set_tracking_uri("http://some-cool-uri") + yield + mlflow.set_tracking_uri(None) + + +def _get_expected_table_info_row(path, data_format, version=None): + expected_path = "file:%s" % path + if version is None: + return "path={path},format={format}".format(path=expected_path, format=data_format) + return "path={path},version={version},format={format}".format( + path=expected_path, version=version, format=data_format + ) + + +# Note that the following tests run one-after-the-other and operate on the SAME spark_session +# (it is not reset between tests) + + +@pytest.mark.large +def test_autologging_disabled_logging_datasource_with_different_formats( + spark_session, format_to_file_path +): + mlflow.spark.autolog(disable=True) + for data_format, file_path in format_to_file_path.items(): + df = ( + spark_session.read.format(data_format) + .option("header", "true") + .option("inferSchema", "true") + .load(file_path) + ) + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + df.collect() + time.sleep(1) + run = mlflow.get_run(run_id) + _assert_spark_data_not_logged(run=run) + + +@pytest.mark.large +def test_autologging_disabled_logging_with_or_without_active_run( + spark_session, format_to_file_path +): + mlflow.spark.autolog(disable=True) + data_format = list(format_to_file_path.keys())[0] + file_path = format_to_file_path[data_format] + df = ( + spark_session.read.format(data_format) + .option("header", "true") + .option("inferSchema", "true") + .load(file_path) + ) + + # Reading data source before starting a run + df.filter("number1 > 0").collect() + df.limit(2).collect() + df.collect() + + # If there was any tag info collected it will be logged here + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + time.sleep(1) + + # Confirm nothing was logged. + run = mlflow.get_run(run_id) + _assert_spark_data_not_logged(run=run) + + # Reading data source during an active run + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + df.collect() + time.sleep(1) + run = mlflow.get_run(run_id) + _assert_spark_data_not_logged(run=run) + + +@pytest.mark.large +def test_autologging_disabled_then_enabled(spark_session, format_to_file_path): + mlflow.spark.autolog(disable=True) + data_format = list(format_to_file_path.keys())[0] + file_path = format_to_file_path[data_format] + df = ( + spark_session.read.format(data_format) + .option("header", "true") + .option("inferSchema", "true") + .load(file_path) + ) + # Logging is disabled here. + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + df.collect() + time.sleep(1) + run = mlflow.get_run(run_id) + _assert_spark_data_not_logged(run=run) + + # Logging is enabled here. + mlflow.spark.autolog(disable=False) + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + df.filter("number1 > 0").collect() + time.sleep(1) + run = mlflow.get_run(run_id) + _assert_spark_data_logged(run=run, path=file_path, data_format=data_format) + + +@pytest.mark.large +def test_enabling_autologging_does_not_throw_when_spark_hasnt_been_started(spark_session): + spark_session.stop() + mlflow.spark.autolog(disable=True) diff --git a/tests/spark_autologging/utils.py b/tests/spark_autologging/utils.py index fa7626916212f..f1e97e622cbfe 100644 --- a/tests/spark_autologging/utils.py +++ b/tests/spark_autologging/utils.py @@ -39,6 +39,10 @@ def _assert_spark_data_logged(run, path, data_format, version=None): assert table_info_tag == _get_expected_table_info_row(path, data_format, version) +def _assert_spark_data_not_logged(run): + assert _SPARK_TABLE_INFO_TAG_NAME not in run.data.tags + + def _get_or_create_spark_session(jars=None): jar_path = jars if jars is not None else _get_mlflow_spark_jar_path() return SparkSession.builder.config("spark.jars", jar_path).master("local[*]").getOrCreate() From 26da1b4eb9700016fc4df7e1ecf2f63adbe7ff84 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Fri, 11 Dec 2020 11:23:53 -0800 Subject: [PATCH 06/22] Remove unused method Signed-off-by: Mohamad Arabi --- .../spark_autologging/test_spark_disable_autologging.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py index b5283dd63bc56..7af3beac12102 100644 --- a/tests/spark_autologging/test_spark_disable_autologging.py +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -34,15 +34,6 @@ def http_tracking_uri_mock(): mlflow.set_tracking_uri(None) -def _get_expected_table_info_row(path, data_format, version=None): - expected_path = "file:%s" % path - if version is None: - return "path={path},format={format}".format(path=expected_path, format=data_format) - return "path={path},version={version},format={format}".format( - path=expected_path, version=version, format=data_format - ) - - # Note that the following tests run one-after-the-other and operate on the SAME spark_session # (it is not reset between tests) From b2d077130ccfd70538ce4ed054f229e0aa43745a Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Fri, 11 Dec 2020 11:26:06 -0800 Subject: [PATCH 07/22] Remove unused method part 2 Signed-off-by: Mohamad Arabi --- .../test_spark_disable_autologging.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py index 7af3beac12102..97b22eb5f31e3 100644 --- a/tests/spark_autologging/test_spark_disable_autologging.py +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -18,22 +18,6 @@ from tests.spark_autologging.utils import file_path # pylint: disable=unused-import -def pytest_generate_tests(metafunc): - """ - Automatically parametrize each each fixture/test that depends on `backend_store_uri` with the - list of backend store URIs. - """ - if "backend_store_uri" in metafunc.fixturenames: - metafunc.parametrize("backend_store_uri", BACKEND_URIS) - - -@pytest.fixture() -def http_tracking_uri_mock(): - mlflow.set_tracking_uri("http://some-cool-uri") - yield - mlflow.set_tracking_uri(None) - - # Note that the following tests run one-after-the-other and operate on the SAME spark_session # (it is not reset between tests) From fa21520d2d7c1b9c095910c0b1e2198f96ae4ff7 Mon Sep 17 00:00:00 2001 From: dbczumar <39497902+dbczumar@users.noreply.github.com> Date: Fri, 11 Dec 2020 17:17:42 -0800 Subject: [PATCH 08/22] Introduce utilities for autologging error tolerance / safety (#3682) * Safe Signed-off-by: Corey Zumar * Keras Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * TF Signed-off-by: Corey Zumar * Fixes Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Some unit tests Signed-off-by: Corey Zumar * More unit tests Signed-off-by: Corey Zumar * Test coverage for safe_patch Signed-off-by: Corey Zumar * Add public API for autologging integration configs Signed-off-by: Mohamad Arabi Signed-off-by: Corey Zumar * Remove big comment Signed-off-by: Corey Zumar * Conf tests Signed-off-by: Corey Zumar * Tests Signed-off-by: Corey Zumar * Mark large Signed-off-by: Corey Zumar * Whitespace Signed-off-by: Corey Zumar * Blackspace Signed-off-by: Corey Zumar * Rename Signed-off-by: Corey Zumar * Simplify, will raise integrations as separate PR Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Remove test_mode_off for now Signed-off-by: Corey Zumar * Support positional arguments Signed-off-by: Corey Zumar * Docstring fix Signed-off-by: Corey Zumar * use match instead of comparison to str(exc) Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Forward args Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Try importing mock from unittest? Signed-off-by: Corey Zumar * Fix import mock in statsmodel Signed-off-by: Corey Zumar * Revert "Fix import mock in statsmodel" This reverts commit a81e8101a3a5da1962bfee22dee1006adf1fd728. Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Support tuple Signed-off-by: Corey Zumar * Address more comments Signed-off-by: Corey Zumar * Stop patching log_param Signed-off-by: Corey Zumar Co-authored-by: Mohamad Arabi --- mlflow/utils/autologging_utils.py | 497 ++++++++++- .../test_autologging_safety_unit.py | 835 ++++++++++++++++++ .../test_autologging_utils.py | 155 +++- 3 files changed, 1459 insertions(+), 28 deletions(-) create mode 100644 tests/autologging/test_autologging_safety_unit.py rename tests/{utils => autologging}/test_autologging_utils.py (73%) diff --git a/mlflow/utils/autologging_utils.py b/mlflow/utils/autologging_utils.py index 518ad3278475b..c775dbaa61ff6 100644 --- a/mlflow/utils/autologging_utils.py +++ b/mlflow/utils/autologging_utils.py @@ -1,10 +1,14 @@ import inspect +import itertools import functools import warnings +import logging import time import contextlib +from abc import abstractmethod import mlflow +from mlflow.entities.run_status import RunStatus from mlflow.utils import gorilla from mlflow.entities import Metric from mlflow.tracking.client import MlflowClient @@ -16,13 +20,18 @@ "please ensure that autologging is enabled before constructing the dataset." ) +# Dict mapping integration name to its config. +AUTOLOGGING_INTEGRATIONS = {} + +_logger = logging.getLogger(__name__) + def try_mlflow_log(fn, *args, **kwargs): """ Catch exceptions and log a warning to avoid autolog throwing. """ try: - fn(*args, **kwargs) + return fn(*args, **kwargs) except Exception as e: # pylint: disable=broad-except warnings.warn("Logging to MLflow failed: " + str(e), stacklevel=2) @@ -66,10 +75,33 @@ def log_fn_args_as_params(fn, args, kwargs, unlogged=[]): # pylint: disable=W01 try_mlflow_log(mlflow.log_params, params_to_log) +def _update_wrapper_extended(wrapper, wrapped): + """ + Update a `wrapper` function to look like the `wrapped` function. This is an extension of + `functools.update_wrapper` that applies the docstring *and* signature of `wrapped` to + `wrapper`, producing a new function. + + :return: A new function with the same implementation as `wrapper` and the same docstring + & signature as `wrapped`. + """ + updated_wrapper = functools.update_wrapper(wrapper, wrapped) + # Assign the signature of the `wrapped` function to the updated wrapper function. + # Certain frameworks may disallow signature inspection, causing `inspect.signature()` to throw. + # One such example is the `tensorflow.estimator.Estimator.export_savedmodel()` function + try: + updated_wrapper.__signature__ = inspect.signature(wrapped) + except Exception: # pylint: disable=broad-except + _logger.debug("Failed to restore original signature for wrapper around %s", wrapped) + return updated_wrapper + + def wrap_patch(destination, name, patch, settings=None): """ Apply a patch while preserving the attributes (e.g. __doc__) of an original function. + TODO(dbczumar): Convert this to an internal method once existing `wrap_patch` calls + outside of `autologging_utils` have been converted to `safe_patch` + :param destination: Patch destination :param name: Name of the attribute at the destination :param patch: Patch function @@ -79,7 +111,8 @@ def wrap_patch(destination, name, patch, settings=None): settings = gorilla.Settings(allow_hit=True, store_hit=True) original = getattr(destination, name) - wrapped = functools.wraps(original)(patch) + wrapped = _update_wrapper_extended(patch, original) + patch = gorilla.Patch(destination, name, wrapped, settings=settings) gorilla.apply(patch) @@ -260,3 +293,463 @@ def batch_metrics_logger(run_id): batch_metrics_logger = BatchMetricsLogger(run_id) yield batch_metrics_logger batch_metrics_logger.flush() + + +def autologging_integration(name): + """ + **All autologging integrations should be decorated with this wrapper.** + + Wraps an autologging function in order to store its configuration arguments. This enables + patch functions to broadly obey certain configurations (e.g., disable=True) without + requiring specific logic to be present in each autologging integration. + """ + + def validate_param_spec(param_spec): + if "disable" not in param_spec or param_spec["disable"].default is not False: + raise Exception( + "Invalid `autolog()` function for integration '{}'. `autolog()` functions" + " must specify a 'disable' argument with default value 'False'".format(name) + ) + + def wrapper(_autolog): + param_spec = inspect.signature(_autolog).parameters + validate_param_spec(param_spec) + + AUTOLOGGING_INTEGRATIONS[name] = {} + default_params = {param.name: param.default for param in param_spec.values()} + + def autolog(*args, **kwargs): + config_to_store = dict(default_params) + config_to_store.update( + {param.name: arg for arg, param in zip(args, param_spec.values())} + ) + config_to_store.update(kwargs) + AUTOLOGGING_INTEGRATIONS[name] = config_to_store + + return _autolog(*args, **kwargs) + + wrapped_autolog = _update_wrapper_extended(autolog, _autolog) + return wrapped_autolog + + return wrapper + + +def get_autologging_config(flavor_name, config_key, default_value=None): + """ + Returns a desired config value for a specified autologging integration. + Returns `None` if specified `flavor_name` has no recorded configs. + If `config_key` is not set on the config object, default value is returned. + + :param flavor_name: An autologging integration flavor name. + :param config_key: The key for the desired config value. + :param default_value: The default_value to return + """ + config = AUTOLOGGING_INTEGRATIONS.get(flavor_name) + if config is not None: + return config.get(config_key, default_value) + else: + return default_value + + +def autologging_is_disabled(flavor_name): + """ + Returns a boolean flag of whether the autologging integration is disabled. + + :param flavor_name: An autologging integration flavor name. + """ + return get_autologging_config(flavor_name, "disable", True) + + +def _is_testing(): + """ + Indicates whether or not autologging functionality is running in test mode (as determined + by the `MLFLOW_AUTOLOGGING_TESTING` environment variable). Test mode performs additional + validation during autologging, including: + + - Checks for the exception safety of arguments passed to model training functions + (i.e. all additional arguments should be "exception safe" functions or classes) + - Disables exception handling for patched function logic, ensuring that patch code + executes without errors during testing + """ + import os + + return os.environ.get("MLFLOW_AUTOLOGGING_TESTING", "false") == "true" + + +# Function attribute used for testing purposes to verify that a given function +# has been wrapped with the `exception_safe_function` decorator +_ATTRIBUTE_EXCEPTION_SAFE = "exception_safe" + + +def exception_safe_function(function): + """ + Wraps the specified function with broad exception handling to guard + against unexpected errors during autologging. + """ + if _is_testing(): + setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True) + + def safe_function(*args, **kwargs): + try: + return function(*args, **kwargs) + except Exception as e: # pylint: disable=broad-except + if _is_testing(): + raise + else: + _logger.warning("Encountered unexpected error during autologging: %s", e) + + safe_function = _update_wrapper_extended(safe_function, function) + return safe_function + + +class ExceptionSafeClass(type): + """ + Metaclass that wraps all functions defined on the specified class with broad error handling + logic to guard against unexpected errors during autlogging. + + Rationale: Patched autologging functions commonly pass additional class instances as arguments + to their underlying original training routines; for example, Keras autologging constructs + a subclass of `keras.callbacks.Callback` and forwards it to `Model.fit()`. To prevent errors + encountered during method execution within such classes from disrupting model training, + this metaclass wraps all class functions in a broad try / catch statement. + + Note: `ExceptionSafeClass` does not handle exceptions in class methods or static methods, + as these are not always Python callables and are difficult to wrap + """ + + def __new__(cls, name, bases, dct): + for m in dct: + if callable(dct[m]): + dct[m] = exception_safe_function(dct[m]) + return type.__new__(cls, name, bases, dct) + + +class PatchFunction: + """ + Base class representing a function patch implementation with a callback for error handling. + `PatchFunction` should be subclassed and used in conjunction with `safe_patch` to + safely modify the implementation of a function. Subclasses of `PatchFunction` should + use `_patch_implementation` to define modified ("patched") function implementations and + `_on_exception` to define cleanup logic when `_patch_implementation` terminates due + to an unhandled exception. + """ + + @abstractmethod + def _patch_implementation(self, original, *args, **kwargs): + """ + Invokes the patch function code. + + :param original: The original, underlying function over which the `PatchFunction` + is being applied. + :param *args: The positional arguments passed to the original function. + :param **kwargs: The keyword arguments passed to the original function. + """ + pass + + @abstractmethod + def _on_exception(self, exception): + """ + Called when an unhandled exception prematurely terminates the execution + of `_patch_implementation`. + + :param exception: The unhandled exception thrown by `_patch_implementation`. + """ + pass + + @classmethod + def call(cls, original, *args, **kwargs): + return cls().__call__(original, *args, **kwargs) + + def __call__(self, original, *args, **kwargs): + try: + return self._patch_implementation(original, *args, **kwargs) + except Exception as e: # pylint: disable=broad-except + try: + self._on_exception(e) + finally: + # Regardless of what happens during the `_on_exception` callback, reraise + # the original implementation exception once the callback completes + raise e + + +def with_managed_run(patch_function): + """ + Given a `patch_function`, returns an `augmented_patch_function` that wraps the execution of + `patch_function` with an active MLflow run. The following properties apply: + + - An MLflow run is only created if there is no active run present when the + patch function is executed + + - If an active run is created by the `augmented_patch_function`, it is terminated + with the `FINISHED` state at the end of function execution + + - If an active run is created by the `augmented_patch_function`, it is terminated + with the `FAILED` if an unhandled exception is thrown during function execution + + Note that, if nested runs or non-fluent runs are created by `patch_function`, `patch_function` + is responsible for terminating them by the time it terminates (or in the event of an exception). + + :param patch_function: A `PatchFunction` class definition or a function object + compatible with `safe_patch`. + """ + + if inspect.isclass(patch_function): + + class PatchWithManagedRun(patch_function): + def __init__(self): + super(PatchWithManagedRun, self).__init__() + self.managed_run = None + + def _patch_implementation(self, original, *args, **kwargs): + if not mlflow.active_run(): + self.managed_run = try_mlflow_log(mlflow.start_run) + + result = super(PatchWithManagedRun, self)._patch_implementation( + original, *args, **kwargs + ) + + if self.managed_run: + try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FINISHED)) + + return result + + def _on_exception(self, e): + if self.managed_run: + try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED)) + super(PatchWithManagedRun, self)._on_exception(e) + + return PatchWithManagedRun + + else: + + def patch_with_managed_run(original, *args, **kwargs): + managed_run = None + if not mlflow.active_run(): + managed_run = try_mlflow_log(mlflow.start_run) + + try: + result = patch_function(original, *args, **kwargs) + except: + if managed_run: + try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED)) + raise + else: + if managed_run: + try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FINISHED)) + return result + + return patch_with_managed_run + + +def safe_patch( + autologging_integration, destination, function_name, patch_function, manage_run=False +): + """ + Patches the specified `function_name` on the specified `destination` class for autologging + purposes, replacing its implementation with an error-safe copy of the specified patch + `function` with the following error handling behavior: + + - Exceptions thrown from the underlying / original function + (`.`) are propagated to the caller. + + - Exceptions thrown from other parts of the patched implementation (`patch_function`) + are caught and logged as warnings. + + + :param autologging_integration: The name of the autologging integration associated with the + patch. + :param destination: The Python class on which the patch is being defined. + :param function_name: The name of the function to patch on the specified `destination` class. + :param patch_function: The patched function code to apply. This is either a `PatchFunction` + class definition or a function object. If it is a function object, the + first argument should be reserved for an `original` method argument + representing the underlying / original function. Subsequent arguments + should be identical to those of the original function being patched. + :param manage_run: If `True`, applies the `with_managed_run` wrapper to the specified + `patch_function`, which automatically creates & terminates an MLflow + active run during patch code execution if necessary. If `False`, + does not apply the `with_managed_run` wrapper to the specified + `patch_function`. + """ + if manage_run: + patch_function = with_managed_run(patch_function) + + patch_is_class = inspect.isclass(patch_function) + if patch_is_class: + assert issubclass(patch_function, PatchFunction) + else: + assert callable(patch_function) + + def safe_patch_function(*args, **kwargs): + """ + A safe wrapper around the specified `patch_function` implementation designed to + handle exceptions thrown during the execution of `patch_function`. This wrapper + distinguishes exceptions thrown from the underlying / original function + (`.`) from exceptions thrown from other parts of + `patch_function`. This distinction is made by passing an augmented version of the + underlying / original function to `patch_function` that uses nonlocal state to track + whether or not it has been executed and whether or not it threw an exception. + + Exceptions thrown from the underlying / original function are propagated to the caller, + while exceptions thrown from other parts of `patch_function` are caught and logged as + warnings. + """ + original = gorilla.get_original_attribute(destination, function_name) + + # If the autologging integration associated with this patch is disabled, + # call the original function and return + if autologging_is_disabled(autologging_integration): + return original(*args, **kwargs) + + # Whether or not the original / underlying function has been called during the + # execution of patched code + original_has_been_called = False + # The value returned by the call to the original / underlying function during + # the execution of patched code + original_result = None + # Whether or not an exception was raised from within the original / underlying function + # during the execution of patched code + failed_during_original = False + + try: + + def call_original(*og_args, **og_kwargs): + try: + if _is_testing(): + _validate_args(args, kwargs, og_args, og_kwargs) + + nonlocal original_has_been_called + original_has_been_called = True + + nonlocal original_result + original_result = original(*og_args, **og_kwargs) + return original_result + except Exception: # pylint: disable=broad-except + nonlocal failed_during_original + failed_during_original = True + raise + + # Apply the name, docstring, and signature of `original` to `call_original`. + # This is important because several autologging patch implementations inspect + # the signature of the `original` argument during execution + call_original = _update_wrapper_extended(call_original, original) + + if patch_is_class: + patch_function.call(call_original, *args, **kwargs) + else: + patch_function(call_original, *args, **kwargs) + + except Exception as e: # pylint: disable=broad-except + # Exceptions thrown during execution of the original function should be propagated + # to the caller. Additionally, exceptions encountered during test mode should be + # reraised to detect bugs in autologging implementations + if failed_during_original or _is_testing(): + raise + + _logger.warning( + "Encountered unexpected error during %s autologging: %s", autologging_integration, e + ) + + if original_has_been_called: + return original_result + else: + return original(*args, **kwargs) + + wrap_patch(destination, function_name, safe_patch_function) + + +def _validate_args( + user_call_args, user_call_kwargs, autologging_call_args, autologging_call_kwargs +): + """ + Used for testing purposes to verify that, when a patched ML function calls its underlying + / original ML function, the following properties are satisfied: + + - All arguments supplied to the patched ML function are forwarded to the + original ML function + - Any additional arguments supplied to the original function are exception safe (i.e. + they are either functions decorated with the `@exception_safe_function` decorator + or classes / instances of classes with type `ExceptionSafeClass` + """ + + def _validate_new_input(inp): + """ + Validates a new input (arg or kwarg) introduced to the underlying / original ML function + call during the execution of a patched ML function. The new input is valid if: + + - The new input is a function that has been decorated with `exception_safe_function` + - OR the new input is a class with the `ExceptionSafeClass` metaclass + - OR the new input is a list and each of its elements is valid according to the + these criteria + """ + if type(inp) == list: + for item in inp: + _validate_new_input(item) + elif callable(inp): + assert getattr(inp, _ATTRIBUTE_EXCEPTION_SAFE, False), ( + "New function argument '{}' passed to original function is not exception-safe." + " Please decorate the function with `exception_safe_function`.".format(inp) + ) + else: + assert hasattr(inp, "__class__") and type(inp.__class__) == ExceptionSafeClass, ( + "Invalid new input '{}'. New args / kwargs introduced to `original` function" + " calls by patched code must either be functions decorated with" + "`exception_safe_function`, instances of classes with the `ExceptionSafeClass`" + " metaclass safe or lists of such exception safe functions / classes.".format(inp) + ) + + def _validate(autologging_call_input, user_call_input=None): + """ + Validates that the specified `autologging_call_input` and `user_call_input` + are compatible. If `user_call_input` is `None`, then `autologging_call_input` + is regarded as a new input added by autologging and is validated using + `_validate_new_input`. Otherwise, the following properties must hold: + + - `autologging_call_input` and `user_call_input` must have the same type + (referred to as "input type") + - if the input type is a tuple, list or dictionary, then `autologging_call_input` must + be equivalent to `user_call_input` or be a superset of `user_call_input` + - for all other input types, `autologging_call_input` and `user_call_input` + must be equivalent by reference equality or by object equality + """ + if user_call_input is None and autologging_call_input is not None: + _validate_new_input(autologging_call_input) + return + + assert type(autologging_call_input) == type( + user_call_input + ), "Type of input to original function '{}' does not match expected type '{}'".format( + type(autologging_call_input), type(user_call_input) + ) + + if type(autologging_call_input) in [list, tuple]: + length_difference = len(autologging_call_input) - len(user_call_input) + assert length_difference >= 0, ( + "{} expected inputs are missing from the call" + " to the original function.".format(length_difference) + ) + # If the autologging call input is longer than the user call input, we `zip_longest` + # will pad the user call input with `None` values to ensure that the subsequent calls + # to `_validate` identify new inputs added by the autologging call + for a, u in itertools.zip_longest(autologging_call_input, user_call_input): + _validate(a, u) + elif type(autologging_call_input) == dict: + assert set(user_call_input.keys()).issubset(set(autologging_call_input.keys())), ( + "Keyword or dictionary arguments to original function omit" + " one or more expected keys: '{}'".format( + set(user_call_input.keys()) - set(autologging_call_input.keys()) + ) + ) + for key in autologging_call_input.keys(): + _validate(autologging_call_input[key], user_call_input.get(key, None)) + else: + assert ( + autologging_call_input is user_call_input + or autologging_call_input == user_call_input + ), ( + "Input to original function does not match expected input." + " Original: '{}'. Expected: '{}'".format(autologging_call_input, user_call_input) + ) + + _validate(autologging_call_args, user_call_args) + _validate(autologging_call_kwargs, user_call_kwargs) diff --git a/tests/autologging/test_autologging_safety_unit.py b/tests/autologging/test_autologging_safety_unit.py new file mode 100644 index 0000000000000..fbb0e36edb1be --- /dev/null +++ b/tests/autologging/test_autologging_safety_unit.py @@ -0,0 +1,835 @@ +# pylint: disable=unused-argument + +import copy +import inspect +import os +import pytest +from unittest import mock + +import mlflow +import mlflow.utils.autologging_utils as autologging_utils +from mlflow.entities import RunStatus +from mlflow.tracking.client import MlflowClient +from mlflow.utils.autologging_utils import ( + safe_patch, + autologging_integration, + exception_safe_function, + ExceptionSafeClass, + PatchFunction, + with_managed_run, + _validate_args, + _is_testing, +) + + +pytestmark = pytest.mark.large + + +@pytest.fixture +def test_mode_on(): + with mock.patch("mlflow.utils.autologging_utils._is_testing") as testing_mock: + testing_mock.return_value = True + assert autologging_utils._is_testing() + yield + + +PATCH_DESTINATION_FN_DEFAULT_RESULT = "original_result" + + +@pytest.fixture +def patch_destination(): + class PatchObj: + def __init__(self): + self.fn_call_count = 0 + + def fn(self, *args, **kwargs): + self.fn_call_count += 1 + return PATCH_DESTINATION_FN_DEFAULT_RESULT + + return PatchObj() + + +@pytest.fixture +def test_autologging_integration(): + integration_name = "test_integration" + + @autologging_integration(integration_name) + def autolog(disable=False): + pass + + autolog() + + return integration_name + + +def test_is_testing_respects_environment_variable(): + try: + prev_env_var_value = os.environ.pop("MLFLOW_AUTOLOGGING_TESTING", None) + assert not _is_testing() + + os.environ["MLFLOW_AUTOLOGGING_TESTING"] = "false" + assert not _is_testing() + + os.environ["MLFLOW_AUTOLOGGING_TESTING"] = "true" + assert _is_testing() + finally: + if prev_env_var_value: + os.environ["MLFLOW_AUTOLOGGING_TESTING"] = prev_env_var_value + else: + del os.environ["MLFLOW_AUTOLOGGING_TESTING"] + + +def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation( + patch_destination, test_autologging_integration +): + + foo_val = None + bar_val = None + + def patch_impl(original, foo, bar=10): + nonlocal foo_val + nonlocal bar_val + foo_val = foo + bar_val = bar + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + patch_destination.fn(foo=7, bar=11) + assert foo_val == 7 + assert bar_val == 11 + + +def test_safe_patch_forwards_expected_arguments_to_class_based_patch( + patch_destination, test_autologging_integration +): + + foo_val = None + bar_val = None + + class TestPatch(PatchFunction): + def _patch_implementation(self, original, foo, bar=10): # pylint: disable=arguments-differ + nonlocal foo_val + nonlocal bar_val + foo_val = foo + bar_val = bar + + def _on_exception(self, exception): + pass + + safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch) + with mock.patch( + "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call + ) as call_mock: + patch_destination.fn(foo=7, bar=11) + assert call_mock.call_count == 1 + assert foo_val == 7 + assert bar_val == 11 + + +def test_safe_patch_provides_expected_original_function( + patch_destination, test_autologging_integration +): + def original_fn(foo, bar=10): + return { + "foo": foo, + "bar": bar, + } + + patch_destination.fn = original_fn + + def patch_impl(original, foo, bar): + return original(foo + 1, bar + 2) + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + assert patch_destination.fn(1, 2) == {"foo": 2, "bar": 4} + + +def test_safe_patch_provides_expected_original_function_to_class_based_patch( + patch_destination, test_autologging_integration +): + def original_fn(foo, bar=10): + return { + "foo": foo, + "bar": bar, + } + + patch_destination.fn = original_fn + + class TestPatch(PatchFunction): + def _patch_implementation(self, original, foo, bar=10): # pylint: disable=arguments-differ + return original(foo + 1, bar + 2) + + def _on_exception(self, exception): + pass + + safe_patch(test_autologging_integration, patch_destination, "fn", TestPatch) + with mock.patch( + "mlflow.utils.autologging_utils.PatchFunction.call", wraps=TestPatch.call + ) as call_mock: + assert patch_destination.fn(1, 2) == {"foo": 2, "bar": 4} + assert call_mock.call_count == 1 + + +def test_safe_patch_propagates_exceptions_raised_from_original_function( + patch_destination, test_autologging_integration +): + + exc_to_throw = Exception("Bad original function") + + def original(*args, **kwargs): + raise exc_to_throw + + patch_destination.fn = original + + patch_impl_called = False + + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_called + patch_impl_called = True + return original(*args, **kwargs) + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + + with pytest.raises(Exception) as exc: + patch_destination.fn() + + assert exc.value == exc_to_throw + assert patch_impl_called + + +def test_safe_patch_logs_exceptions_raised_outside_of_original_function_as_warnings( + patch_destination, test_autologging_integration +): + + exc_to_throw = Exception("Bad patch implementation") + + def patch_impl(original, *args, **kwargs): + raise exc_to_throw + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock: + assert patch_destination.fn() == PATCH_DESTINATION_FN_DEFAULT_RESULT + assert logger_mock.call_count == 1 + message, formatting_arg1, formatting_arg2 = logger_mock.call_args[0] + assert "Encountered unexpected error" in message + assert formatting_arg1 == test_autologging_integration + assert formatting_arg2 == exc_to_throw + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode( + patch_destination, test_autologging_integration +): + + exc_to_throw = Exception("Bad patch implementation") + + def patch_impl(original, *args, **kwargs): + raise exc_to_throw + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + with pytest.raises(Exception) as exc: + patch_destination.fn() + + assert exc.value == exc_to_throw + + +def test_safe_patch_calls_original_function_when_patch_preamble_throws( + patch_destination, test_autologging_integration +): + + patch_impl_called = False + + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_called + patch_impl_called = True + raise Exception("Bad patch preamble") + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + assert patch_destination.fn() == PATCH_DESTINATION_FN_DEFAULT_RESULT + assert patch_destination.fn_call_count == 1 + assert patch_impl_called + + +def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws( + patch_destination, test_autologging_integration +): + + patch_impl_called = False + + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_called + patch_impl_called = True + original(*args, **kwargs) + raise Exception("Bad patch postamble") + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + assert patch_destination.fn() == PATCH_DESTINATION_FN_DEFAULT_RESULT + assert patch_destination.fn_call_count == 1 + assert patch_impl_called + + +def test_safe_patch_respects_disable_flag(patch_destination): + + patch_impl_call_count = 0 + + @autologging_integration("test_respects_disable") + def autolog(disable=False): + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_call_count + patch_impl_call_count += 1 + return original(*args, **kwargs) + + safe_patch("test_respects_disable", patch_destination, "fn", patch_impl) + + autolog(disable=False) + patch_destination.fn() + assert patch_impl_call_count == 1 + + autolog(disable=True) + patch_destination.fn() + assert patch_impl_call_count == 1 + + +def test_safe_patch_returns_original_result_and_ignores_patch_return_value( + patch_destination, test_autologging_integration +): + + patch_impl_called = False + + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_called + patch_impl_called = True + return 10 + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + assert patch_destination.fn() == PATCH_DESTINATION_FN_DEFAULT_RESULT + assert patch_destination.fn_call_count == 1 + assert patch_impl_called + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_safe_patch_validates_arguments_to_original_function_in_test_mode( + patch_destination, test_autologging_integration +): + def patch_impl(original, *args, **kwargs): + return original("1", "2", "3") + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + + with pytest.raises(Exception, match="does not match expected input"), mock.patch( + "mlflow.utils.autologging_utils._validate_args", wraps=autologging_utils._validate_args + ) as validate_mock: + patch_destination.fn("a", "b", "c") + + assert validate_mock.call_count == 1 + + +def test_safe_patch_manages_run_if_specified(patch_destination, test_autologging_integration): + + active_run = None + + def patch_impl(original, *args, **kwargs): + nonlocal active_run + active_run = mlflow.active_run() + return original(*args, **kwargs) + + with mock.patch( + "mlflow.utils.autologging_utils.with_managed_run", wraps=with_managed_run + ) as managed_run_mock: + safe_patch( + test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True + ) + patch_destination.fn() + assert managed_run_mock.call_count == 1 + assert active_run is not None + assert active_run.info.run_id is not None + + +def test_safe_patch_does_not_manage_run_if_unspecified( + patch_destination, test_autologging_integration +): + + active_run = None + + def patch_impl(original, *args, **kwargs): + nonlocal active_run + active_run = mlflow.active_run() + return original(*args, **kwargs) + + with mock.patch( + "mlflow.utils.autologging_utils.with_managed_run", wraps=with_managed_run + ) as managed_run_mock: + safe_patch( + test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=False + ) + patch_destination.fn() + assert managed_run_mock.call_count == 0 + assert active_run is None + + +def test_safe_patch_preserves_signature_of_patched_function( + patch_destination, test_autologging_integration +): + def original(a, b, c=10, *, d=11): + return 10 + + patch_destination.fn = original + + patch_impl_called = False + + def patch_impl(original, *args, **kwargs): + nonlocal patch_impl_called + patch_impl_called = True + return original(*args, **kwargs) + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + patch_destination.fn(1, 2) + assert patch_impl_called + assert inspect.signature(patch_destination.fn) == inspect.signature(original) + + +def test_safe_patch_provides_original_function_with_expected_signature( + patch_destination, test_autologging_integration +): + def original(a, b, c=10, *, d=11): + return 10 + + patch_destination.fn = original + + original_signature = False + + def patch_impl(original, *args, **kwargs): + nonlocal original_signature + original_signature = inspect.signature(original) + return original(*args, **kwargs) + + safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl) + patch_destination.fn(1, 2) + assert original_signature == inspect.signature(original) + + +def test_exception_safe_function_exhibits_expected_behavior_in_standard_mode(): + assert not autologging_utils._is_testing() + + @exception_safe_function + def non_throwing_function(): + return 10 + + assert non_throwing_function() == 10 + + exc_to_throw = Exception("bad implementation") + + @exception_safe_function + def throwing_function(): + raise exc_to_throw + + with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock: + throwing_function() + assert logger_mock.call_count == 1 + message, formatting_arg = logger_mock.call_args[0] + assert "unexpected error during autologging" in message + assert formatting_arg == exc_to_throw + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_exception_safe_function_exhibits_expected_behavior_in_test_mode(): + assert autologging_utils._is_testing() + + @exception_safe_function + def non_throwing_function(): + return 10 + + assert non_throwing_function() == 10 + + exc_to_throw = Exception("function error") + + @exception_safe_function + def throwing_function(): + raise exc_to_throw + + with pytest.raises(Exception) as exc: + throwing_function() + + assert exc.value == exc_to_throw + + +def test_exception_safe_class_exhibits_expected_behavior_in_standard_mode(): + assert not autologging_utils._is_testing() + + class NonThrowingClass(metaclass=ExceptionSafeClass): + def function(self): + return 10 + + assert NonThrowingClass().function() == 10 + + exc_to_throw = Exception("function error") + + class ThrowingClass(metaclass=ExceptionSafeClass): + def function(self): + raise exc_to_throw + + with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock: + ThrowingClass().function() + + assert logger_mock.call_count == 1 + + message, formatting_arg = logger_mock.call_args[0] + assert "unexpected error during autologging" in message + assert formatting_arg == exc_to_throw + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_exception_safe_class_exhibits_expected_behavior_in_test_mode(): + assert autologging_utils._is_testing() + + class NonThrowingClass(metaclass=ExceptionSafeClass): + def function(self): + return 10 + + assert NonThrowingClass().function() == 10 + + exc_to_throw = Exception("function error") + + class ThrowingClass(metaclass=ExceptionSafeClass): + def function(self): + raise exc_to_throw + + with pytest.raises(Exception) as exc: + ThrowingClass().function() + + assert exc.value == exc_to_throw + + +def test_patch_function_class_call_invokes_implementation_and_returns_result(): + class TestPatchFunction(PatchFunction): + def _patch_implementation(self, original, *args, **kwargs): + return 10 + + def _on_exception(self, exception): + pass + + assert TestPatchFunction.call("foo", lambda: "foo") == 10 + + +def test_patch_function_class_call_handles_exceptions_properly(): + + called_on_exception = False + + class TestPatchFunction(PatchFunction): + def _patch_implementation(self, original, *args, **kwargs): + raise Exception("implementation exception") + + def _on_exception(self, exception): + nonlocal called_on_exception + called_on_exception = True + raise Exception("on_exception exception") + + # Even if an exception is thrown from `_on_exception`, we expect the original + # exception from the implementation to be surfaced to the caller + with pytest.raises(Exception, match="implementation exception"): + TestPatchFunction.call("foo", lambda: "foo") + + assert called_on_exception + + +def test_with_managed_runs_yields_functions_and_classes_as_expected(): + def patch_function(original, *args, **kwargs): + pass + + class TestPatch(PatchFunction): + def _patch_implementation(self, original, *args, **kwargs): + pass + + def _on_exception(self, exception): + pass + + assert callable(with_managed_run(patch_function)) + assert inspect.isclass(with_managed_run(TestPatch)) + + +def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior(): + client = MlflowClient() + + @with_managed_run + def patch_function(original, *args, **kwargs): + return mlflow.active_run() + + run1 = patch_function(lambda: "foo") + run1_status = client.get_run(run1.info.run_id).info.status + assert RunStatus.from_string(run1_status) == RunStatus.FINISHED + + with mlflow.start_run() as active_run: + run2 = patch_function(lambda: "foo") + + assert run2 == active_run + run2_status = client.get_run(run2.info.run_id).info.status + assert RunStatus.from_string(run2_status) == RunStatus.FINISHED + + +def test_with_managed_run_with_throwing_function_exhibits_expected_behavior(): + client = MlflowClient() + patch_function_active_run = None + + @with_managed_run + def patch_function(original, *args, **kwargs): + nonlocal patch_function_active_run + patch_function_active_run = mlflow.active_run() + raise Exception("bad implementation") + + with pytest.raises(Exception): + patch_function(lambda: "foo") + + assert patch_function_active_run is not None + status1 = client.get_run(patch_function_active_run.info.run_id).info.status + assert RunStatus.from_string(status1) == RunStatus.FAILED + + with mlflow.start_run() as active_run, pytest.raises(Exception): + patch_function(lambda: "foo") + assert patch_function_active_run == active_run + # `with_managed_run` should not terminate a preexisting MLflow run, + # even if the patch function throws + status2 = client.get_run(active_run.info.run_id).info.status + assert RunStatus.from_string(status2) == RunStatus.FINISHED + + +def test_with_managed_run_with_non_throwing_class_exhibits_expected_behavior(): + client = MlflowClient() + + @with_managed_run + class TestPatch(PatchFunction): + def _patch_implementation(self, original, *args, **kwargs): + return mlflow.active_run() + + def _on_exception(self, exception): + pass + + run1 = TestPatch.call(lambda: "foo") + run1_status = client.get_run(run1.info.run_id).info.status + assert RunStatus.from_string(run1_status) == RunStatus.FINISHED + + with mlflow.start_run() as active_run: + run2 = TestPatch.call(lambda: "foo") + + assert run2 == active_run + run2_status = client.get_run(run2.info.run_id).info.status + assert RunStatus.from_string(run2_status) == RunStatus.FINISHED + + +def test_with_managed_run_with_throwing_class_exhibits_expected_behavior(): + client = MlflowClient() + patch_function_active_run = None + + @with_managed_run + class TestPatch(PatchFunction): + def _patch_implementation(self, original, *args, **kwargs): + nonlocal patch_function_active_run + patch_function_active_run = mlflow.active_run() + raise Exception("bad implementation") + + def _on_exception(self, exception): + pass + + with pytest.raises(Exception): + TestPatch.call(lambda: "foo") + + assert patch_function_active_run is not None + status1 = client.get_run(patch_function_active_run.info.run_id).info.status + assert RunStatus.from_string(status1) == RunStatus.FAILED + + with mlflow.start_run() as active_run, pytest.raises(Exception): + TestPatch.call(lambda: "foo") + assert patch_function_active_run == active_run + # `with_managed_run` should not terminate a preexisting MLflow run, + # even if the patch function throws + status2 = client.get_run(active_run.info.run_id).info.status + assert RunStatus.from_string(status2) == RunStatus.FINISHED + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_succeeds_when_arg_sets_are_equivalent_or_identical(): + args = (1, "b", ["c"]) + kwargs = { + "foo": ["bar"], + "biz": {"baz": 5}, + } + + _validate_args(args, kwargs, args, kwargs) + _validate_args(args, None, args, None) + _validate_args(None, kwargs, None, kwargs) + + args_copy = copy.deepcopy(args) + kwargs_copy = copy.deepcopy(kwargs) + + _validate_args(args, kwargs, args_copy, kwargs_copy) + _validate_args(args, None, args_copy, None) + _validate_args(None, kwargs, None, kwargs_copy) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_throws_when_extra_args_are_not_functions_classes_or_lists(): + user_call_args = (1, "b", ["c"]) + user_call_kwargs = { + "foo": ["bar"], + "biz": {"baz": 5}, + } + + invalid_type_autologging_call_args = copy.deepcopy(user_call_args) + invalid_type_autologging_call_args[2].append(10) + invalid_type_autologging_call_kwargs = copy.deepcopy(user_call_kwargs) + invalid_type_autologging_call_kwargs["new"] = {} + + with pytest.raises(Exception, match="Invalid new input"): + _validate_args( + user_call_args, user_call_kwargs, invalid_type_autologging_call_args, user_call_kwargs + ) + + with pytest.raises(Exception, match="Invalid new input"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_type_autologging_call_kwargs + ) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_throws_when_extra_args_are_not_exception_safe(): + user_call_args = (1, "b", ["c"]) + user_call_kwargs = { + "foo": ["bar"], + "biz": {"baz": 5}, + } + + class Unsafe: + pass + + unsafe_autologging_call_args = copy.deepcopy(user_call_args) + unsafe_autologging_call_args += (lambda: "foo",) + unsafe_autologging_call_kwargs1 = copy.deepcopy(user_call_kwargs) + unsafe_autologging_call_kwargs1["foo"].append(Unsafe()) + + with pytest.raises(Exception, match="not exception-safe"): + _validate_args( + user_call_args, user_call_kwargs, unsafe_autologging_call_args, user_call_kwargs + ) + + with pytest.raises(Exception, match="Invalid new input"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, unsafe_autologging_call_kwargs1 + ) + + unsafe_autologging_call_kwargs2 = copy.deepcopy(user_call_kwargs) + unsafe_autologging_call_kwargs2["biz"]["new"] = Unsafe() + + with pytest.raises(Exception, match="Invalid new input"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, unsafe_autologging_call_kwargs2 + ) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_succeeds_when_extra_args_are_exception_safe_functions_or_classes(): + user_call_args = (1, "b", ["c"]) + user_call_kwargs = { + "foo": ["bar"], + } + + class Safe(metaclass=ExceptionSafeClass): + pass + + autologging_call_args = copy.deepcopy(user_call_args) + autologging_call_args[2].append(Safe()) + autologging_call_args += (exception_safe_function(lambda: "foo"),) + + autologging_call_kwargs = copy.deepcopy(user_call_kwargs) + autologging_call_kwargs["foo"].append(exception_safe_function(lambda: "foo")) + autologging_call_kwargs["new"] = Safe() + + _validate_args(user_call_args, user_call_kwargs, autologging_call_args, autologging_call_kwargs) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_throws_when_args_are_omitted(): + user_call_args = (1, "b", ["c"], {"d": "e"}) + user_call_kwargs = { + "foo": ["bar"], + "biz": {"baz": 4, "fuzz": 5}, + } + + invalid_autologging_call_args_1 = copy.deepcopy(user_call_args) + invalid_autologging_call_args_1[2].pop() + invalid_autologging_call_kwargs_1 = copy.deepcopy(user_call_kwargs) + invalid_autologging_call_kwargs_1["foo"].pop() + + with pytest.raises(Exception, match="missing from the call"): + _validate_args( + user_call_args, user_call_kwargs, invalid_autologging_call_args_1, user_call_kwargs + ) + + with pytest.raises(Exception, match="missing from the call"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_1 + ) + + invalid_autologging_call_args_2 = copy.deepcopy(user_call_args)[1:] + invalid_autologging_call_kwargs_2 = copy.deepcopy(user_call_kwargs) + invalid_autologging_call_kwargs_2.pop("foo") + + with pytest.raises(Exception, match="missing from the call"): + _validate_args( + user_call_args, user_call_kwargs, invalid_autologging_call_args_2, user_call_kwargs + ) + + with pytest.raises(Exception, match="omit one or more expected keys"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_2 + ) + + invalid_autologging_call_args_3 = copy.deepcopy(user_call_args) + invalid_autologging_call_args_3[3].pop("d") + invalid_autologging_call_kwargs_3 = copy.deepcopy(user_call_kwargs) + invalid_autologging_call_kwargs_3["biz"].pop("baz") + + with pytest.raises(Exception, match="omit one or more expected keys"): + _validate_args( + user_call_args, user_call_kwargs, invalid_autologging_call_args_3, user_call_kwargs + ) + + with pytest.raises(Exception, match="omit one or more expected keys"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_3 + ) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_validate_args_throws_when_arg_types_or_values_are_changed(): + user_call_args = (1, "b", ["c"]) + user_call_kwargs = { + "foo": ["bar"], + } + + invalid_autologging_call_args_1 = copy.deepcopy(user_call_args) + invalid_autologging_call_args_1 = (2,) + invalid_autologging_call_args_1[1:] + invalid_autologging_call_kwargs_1 = copy.deepcopy(user_call_kwargs) + invalid_autologging_call_kwargs_1["foo"] = ["biz"] + + with pytest.raises(Exception, match="does not match expected input"): + _validate_args( + user_call_args, user_call_kwargs, invalid_autologging_call_args_1, user_call_kwargs + ) + + with pytest.raises(Exception, match="does not match expected input"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_1 + ) + + call_arg_1, call_arg_2, _ = copy.deepcopy(user_call_args) + invalid_autologging_call_args_2 = ({"7": 1}, call_arg_1, call_arg_2) + invalid_autologging_call_kwargs_2 = copy.deepcopy(user_call_kwargs) + invalid_autologging_call_kwargs_2["foo"] = 8 + + with pytest.raises(Exception, match="does not match expected type"): + _validate_args( + user_call_args, user_call_kwargs, invalid_autologging_call_args_2, user_call_kwargs + ) + + with pytest.raises(Exception, match="does not match expected type"): + _validate_args( + user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_2 + ) diff --git a/tests/utils/test_autologging_utils.py b/tests/autologging/test_autologging_utils.py similarity index 73% rename from tests/utils/test_autologging_utils.py rename to tests/autologging/test_autologging_utils.py index 0690a07af0e5b..a0182079e7a84 100644 --- a/tests/utils/test_autologging_utils.py +++ b/tests/autologging/test_autologging_utils.py @@ -1,3 +1,5 @@ +# pylint: disable=unused-argument + import inspect import time import pytest @@ -14,7 +16,15 @@ resolve_input_example_and_signature, batch_metrics_logger, BatchMetricsLogger, + autologging_integration, + get_autologging_config, + autologging_is_disabled, ) +from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS + + +pytestmark = pytest.mark.large + # Example function signature we are testing on # def fn(arg1, default1=1, default2=2): @@ -89,7 +99,6 @@ def dummy_fn(arg1, arg2="value2", arg3="value3"): # pylint: disable=W0613 ] -@pytest.mark.large @pytest.mark.parametrize("args,kwargs,expected", log_test_args) def test_log_fn_args_as_params(args, kwargs, expected, start_run): # pylint: disable=W0613 log_fn_args_as_params(dummy_fn, args, kwargs) @@ -100,7 +109,6 @@ def test_log_fn_args_as_params(args, kwargs, expected, start_run): # pylint: di assert params[arg] == value -@pytest.mark.large def test_log_fn_args_as_params_ignores_unwanted_parameters(start_run): # pylint: disable=W0613 args, kwargs, unlogged = ("arg1", {"arg2": "value"}, ["arg1", "arg2", "arg3"]) log_fn_args_as_params(dummy_fn, args, kwargs, unlogged) @@ -115,7 +123,6 @@ def get_func_attrs(f): return (f.__name__, f.__doc__, f.__module__, inspect.signature(f)) -@pytest.mark.large def test_wrap_patch_with_class(): class Math: def add(self, a, b): @@ -135,18 +142,26 @@ def new_add(self, *args, **kwargs): assert Math().add(1, 2) == 6 -@pytest.mark.large +def sample_function_to_patch(a, b): + return a + b + + def test_wrap_patch_with_module(): - def new_log_param(key, value): + import sys + + this_module = sys.modules[__name__] + + def new_sample_function(a, b): """new mlflow.log_param""" - return (key, value) + return a - b - before = get_func_attrs(mlflow.log_param) - wrap_patch(mlflow, mlflow.log_param.__name__, new_log_param) - after = get_func_attrs(mlflow.log_param) + before_attrs = get_func_attrs(mlflow.log_param) + assert sample_function_to_patch(10, 5) == 15 - assert after == before - assert mlflow.log_param("foo", "bar") == ("foo", "bar") + wrap_patch(this_module, sample_function_to_patch.__name__, new_sample_function) + after_attrs = get_func_attrs(mlflow.log_param) + assert after_attrs == before_attrs + assert sample_function_to_patch(10, 5) == 5 @pytest.fixture() @@ -240,7 +255,7 @@ def modifies(_): logger.warning.assert_not_called() -def test_batch_metrics_logger_logs_all_metrics(start_run,): # pylint: disable=unused-argument +def test_batch_metrics_logger_logs_all_metrics(start_run,): run_id = mlflow.active_run().info.run_id with batch_metrics_logger(run_id) as metrics_logger: for i in range(100): @@ -253,7 +268,7 @@ def test_batch_metrics_logger_logs_all_metrics(start_run,): # pylint: disable=u assert metrics_on_run[hex(i)] == i -def test_batch_metrics_logger_flush_logs_to_mlflow(start_run): # pylint: disable=unused-argument +def test_batch_metrics_logger_flush_logs_to_mlflow(start_run): run_id = mlflow.active_run().info.run_id # Need to patch _should_flush() to return False, so that we can manually flush the logger @@ -275,9 +290,7 @@ def test_batch_metrics_logger_flush_logs_to_mlflow(start_run): # pylint: disabl assert metrics_on_run["my_metric"] == 10 -def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio( - start_run, -): # pylint: disable=unused-argument +def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio(start_run,): with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock: run_id = mlflow.active_run().info.run_id with batch_metrics_logger(run_id) as metrics_logger: @@ -319,9 +332,7 @@ def test_batch_metrics_logger_runs_training_and_logging_in_correct_ratio( log_batch_mock.assert_called_once() -def test_batch_metrics_logger_chunks_metrics_when_batch_logging( - start_run, -): # pylint: disable=unused-argument +def test_batch_metrics_logger_chunks_metrics_when_batch_logging(start_run,): with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock: run_id = mlflow.active_run().info.run_id with batch_metrics_logger(run_id) as metrics_logger: @@ -339,7 +350,7 @@ def test_batch_metrics_logger_chunks_metrics_when_batch_logging( assert metric.step == 0 -def test_batch_metrics_logger_records_time_correctly(start_run,): # pylint: disable=unused-argument +def test_batch_metrics_logger_records_time_correctly(start_run,): with mock.patch.object(MlflowClient, "log_batch", wraps=lambda *args, **kwargs: time.sleep(1)): run_id = mlflow.active_run().info.run_id with batch_metrics_logger(run_id) as metrics_logger: @@ -354,9 +365,7 @@ def test_batch_metrics_logger_records_time_correctly(start_run,): # pylint: dis assert metrics_logger.total_training_time >= 2 -def test_batch_metrics_logger_logs_timestamps_as_int_milliseconds( - start_run, -): # pylint: disable=unused-argument +def test_batch_metrics_logger_logs_timestamps_as_int_milliseconds(start_run,): with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock, mock.patch( "time.time", return_value=123.45678901234567890 ): @@ -371,9 +380,7 @@ def test_batch_metrics_logger_logs_timestamps_as_int_milliseconds( assert logged_metric.timestamp == 123456 -def test_batch_metrics_logger_continues_if_log_batch_fails( - start_run, -): # pylint: disable=unused-argument +def test_batch_metrics_logger_continues_if_log_batch_fails(start_run,): with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock: log_batch_mock.side_effect = [Exception("asdf"), None] @@ -396,3 +403,99 @@ def test_batch_metrics_logger_continues_if_log_batch_fails( assert metric.key == "y" assert metric.value == 2 assert metric.step == 1 + + +def test_autologging_integration_calls_underlying_function_correctly(): + @autologging_integration("test_integration") + def autolog(foo=7, disable=False): + return foo + + assert autolog(foo=10) == 10 + + +def test_autologging_integration_stores_and_updates_config(): + @autologging_integration("test_integration") + def autolog(foo=7, bar=10, disable=False): + return foo + + autolog() + assert AUTOLOGGING_INTEGRATIONS["test_integration"] == {"foo": 7, "bar": 10, "disable": False} + autolog(bar=11) + assert AUTOLOGGING_INTEGRATIONS["test_integration"] == {"foo": 7, "bar": 11, "disable": False} + autolog(6, disable=True) + assert AUTOLOGGING_INTEGRATIONS["test_integration"] == {"foo": 6, "bar": 10, "disable": True} + autolog(1, 2, False) + assert AUTOLOGGING_INTEGRATIONS["test_integration"] == {"foo": 1, "bar": 2, "disable": False} + + +def test_autologging_integration_forwards_positional_and_keyword_arguments_as_expected(): + @autologging_integration("test_integration") + def autolog(foo=7, bar=10, disable=False): + return foo, bar, disable + + assert autolog(1, bar=2, disable=True) == (1, 2, True) + + +def test_autologging_integration_validates_structure_of_autolog_function(): + def fn_missing_disable_conf(): + pass + + def fn_bad_disable_conf_1(disable=True): + pass + + # Try to use a falsy value that isn't "false" + def fn_bad_disable_conf_2(disable=0): + pass + + for fn in [fn_missing_disable_conf, fn_bad_disable_conf_1, fn_bad_disable_conf_2]: + with pytest.raises(Exception, match="must specify a 'disable' argument"): + autologging_integration("test")(fn) + + # Failure to apply the @autologging_integration decorator should not create a + # placeholder for configuration state + assert "test" not in AUTOLOGGING_INTEGRATIONS + + +def test_get_autologging_config_returns_configured_values_or_defaults_as_expected(): + + assert get_autologging_config("nonexistent_integration", "foo") is None + + @autologging_integration("test_integration_for_config") + def autolog(foo="bar", t=7, disable=False): + pass + + # Before `autolog()` has been invoked, config values should not be available + assert get_autologging_config("test_integration_for_config", "foo") is None + assert get_autologging_config("test_integration_for_config", "disable") is None + assert get_autologging_config("test_integration_for_config", "t", 10) == 10 + + autolog() + + assert get_autologging_config("test_integration_for_config", "foo") == "bar" + assert get_autologging_config("test_integration_for_config", "disable") is False + assert get_autologging_config("test_integration_for_config", "t", 10) == 7 + assert get_autologging_config("test_integration_for_config", "nonexistent") is None + + autolog(foo="baz") + + assert get_autologging_config("test_integration_for_config", "foo") == "baz" + + +def test_autologging_is_disabled_returns_expected_values(): + + assert autologging_is_disabled("nonexistent_integration") is True + + @autologging_integration("test_integration_for_disable_check") + def autolog(disable=False): + pass + + # Before `autolog()` has been invoked, `autologging_is_disabled` should return False + assert autologging_is_disabled("test_integration_for_disable_check") is True + + autolog(disable=True) + + assert autologging_is_disabled("test_integration_for_disable_check") is True + + autolog(disable=False) + + assert autologging_is_disabled("test_integration_for_disable_check") is False From 2826878017b83ccb80a62ab2fbfa0655a3fa57e1 Mon Sep 17 00:00:00 2001 From: Halil Coban Date: Mon, 14 Dec 2020 01:10:56 +0100 Subject: [PATCH 09/22] reject bool metric value (#3822) * reject bool metric value Signed-off-by: Halil Coban * add comment on why we check for bool Signed-off-by: Halil Coban --- mlflow/utils/validation.py | 4 +++- tests/utils/test_validation.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlflow/utils/validation.py b/mlflow/utils/validation.py index 6fdb6daed55f0..90dcfcac702d7 100644 --- a/mlflow/utils/validation.py +++ b/mlflow/utils/validation.py @@ -69,7 +69,9 @@ def _validate_metric(key, value, timestamp, step): it isn't. """ _validate_metric_name(key) - if not isinstance(value, numbers.Number): + # value must be a Number + # since bool is an instance of Number check for bool additionally + if isinstance(value, bool) or not isinstance(value, numbers.Number): raise MlflowException( "Got invalid value %s for metric '%s' (timestamp=%s). Please specify value as a valid " "double (64-bit floating point)" % (value, key, timestamp), diff --git a/tests/utils/test_validation.py b/tests/utils/test_validation.py index 5f71d11ab012e..e79926bc7a441 100644 --- a/tests/utils/test_validation.py +++ b/tests/utils/test_validation.py @@ -122,6 +122,7 @@ def test_validate_batch_log_data(): Metric("super-long-bad-key" * 1000, 4.0, 0, 0), ] metrics_with_bad_val = [Metric("good-metric-key", "not-a-double-val", 0, 0)] + metrics_with_bool_val = [Metric("good-metric-key", True, 0, 0)] metrics_with_bad_ts = [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)] metrics_with_neg_ts = [Metric("good-metric-key", 1.0, -123, 0)] metrics_with_bad_step = [Metric("good-metric-key", 1.0, 0, "not-a-step")] @@ -145,6 +146,7 @@ def test_validate_batch_log_data(): "metrics": [ metrics_with_bad_key, metrics_with_bad_val, + metrics_with_bool_val, metrics_with_bad_ts, metrics_with_neg_ts, metrics_with_bad_step, From 09f4f24e4d34a2a93443d282a8bdfbd9fbc236bc Mon Sep 17 00:00:00 2001 From: tomasatdatabricks <33237569+tomasatdatabricks@users.noreply.github.com> Date: Sun, 13 Dec 2020 19:53:56 -0800 Subject: [PATCH 10/22] Update schema enforcement (#3798) * initial commit Signed-off-by: tomasatdatabricks * Updated tests to refelct new type conversions rules and to make sure we include hin message when necessary. Signed-off-by: tomasatdatabricks * fix tests. Signed-off-by: tomasatdatabricks * lint. Signed-off-by: tomasatdatabricks * fix. Signed-off-by: tomasatdatabricks * fix. Signed-off-by: tomasatdatabricks * fix. Signed-off-by: tomasatdatabricks * minor fix Signed-off-by: tomasatdatabricks * lint Signed-off-by: tomasatdatabricks * revert Signed-off-by: tomasatdatabricks * update Signed-off-by: tomasatdatabricks * Update doc. Signed-off-by: tomasatdatabricks * fix. Signed-off-by: tomasatdatabricks * fix docs. Signed-off-by: tomasatdatabricks * add hint/warning to schema inference Signed-off-by: tomasatdatabricks * Addressed review comments. Signed-off-by: tomasatdatabricks * Addressed review comments. Signed-off-by: tomasatdatabricks --- docs/source/models.rst | 18 ++++- mlflow/pyfunc/__init__.py | 53 +++++++++++-- mlflow/types/utils.py | 43 +++++++---- ...export_with_loader_module_and_data_path.py | 76 +++++++++++++++++-- 4 files changed, 162 insertions(+), 28 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 79bd9f2ea1c41..2021ad7d19f22 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -146,8 +146,22 @@ names, matching is done by position (i.e. MLflow will only check the number of c Column Type Enforcement """"""""""""""""""""""" The input column types are checked against the signature. MLflow will perform safe type conversions -if necessary. Generally, only upcasts (e.g. integer -> long or float -> double) are considered to be -safe. If the types cannot be made compatible, MLflow will raise an error. +if necessary. Generally, only conversions that are guaranteed to be lossless are allowed. For +example, int -> long or int -> double conversions are ok, long -> double is not. If the types cannot +be made compatible, MLflow will raise an error. + +Handling Integers With Missing Values +""""""""""""""""""""""""""""""""""""" +Integer data with missing values is typically represented as floats in Python. Therefore, data +types of integer columns in Python can vary depending on the data sample. This type variance can +cause schema enforcement errors at runtime since integer and float are not compatible types. For +example, if your training data did not have any missing values for integer column c, its type will +be integer. However, when you attempt to score a sample of the data that does include a missing +value in column c, its type will be float. If your model signature specified c to have integer type, +MLflow will raise an error since it can not convert float to int. Note that MLflow uses python to +serve models and to deploy models to Spark, so this can affect most model deployments. The best way +to avoid this problem is to declare integer columns as doubles (float64) whenever there can be +missing values. How To Log Models With Signatures ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index d12ffb3ddabc8..03d5528e59da4 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -290,6 +290,7 @@ def _enforce_type(name, values: pandas.Series, t: DataType): 1. np.object -> string 2. int -> long (upcast) 3. float -> double (upcast) + 4. int -> double (safe conversion) Any other type mismatch will raise error. """ @@ -310,7 +311,10 @@ def _enforce_type(name, values: pandas.Series, t: DataType): "Failed to convert column {0} from type {1} to {2}.".format(name, values.dtype, t) ) - if values.dtype in (t.to_pandas(), t.to_numpy()): + # NB: Comparison of pandas and numpy data type fails when numpy data type is on the left hand + # side of the comparison operator. It works, however, if pandas type is on the left hand side. + # That is because pandas is aware of numpy. + if t.to_pandas() == values.dtype or t.to_numpy() == values.dtype: # The types are already compatible => conversion is not necessary. return values @@ -321,17 +325,46 @@ def _enforce_type(name, values: pandas.Series, t: DataType): return values numpy_type = t.to_numpy() - is_compatible_type = values.dtype.kind == numpy_type.kind - is_upcast = values.dtype.itemsize <= numpy_type.itemsize - if is_compatible_type and is_upcast: + if values.dtype.kind == numpy_type.kind: + is_upcast = values.dtype.itemsize <= numpy_type.itemsize + elif values.dtype.kind == "u" and numpy_type.kind == "i": + is_upcast = values.dtype.itemsize < numpy_type.itemsize + elif values.dtype.kind in ("i", "u") and numpy_type == np.float64: + # allow (u)int => double conversion + is_upcast = values.dtype.itemsize <= 6 + else: + is_upcast = False + + if is_upcast: return values.astype(numpy_type, errors="raise") else: # NB: conversion between incompatible types (e.g. floats -> ints or # double -> float) are not allowed. While supported by pandas and numpy, # these conversions alter the values significantly. + def all_ints(xs): + return all([pandas.isnull(x) or int(x) == x for x in xs]) + + hint = "" + if ( + values.dtype == np.float64 + and numpy_type.kind in ("i", "u") + and values.hasnans + and all_ints(values) + ): + hint = ( + " Hint: the type mismatch is likely caused by missing values. " + "Integer columns in python can not represent missing values and are therefore " + "encoded as floats. The best way to avoid this problem is to infer the model " + "schema based on a realistic data sample (training dataset) that includes missing " + "values. Alternatively, you can declare integer columns as doubles (float64) " + "whenever these columns may have missing values. See `Handling Integers With " + "Missing Values `_ for more details." + ) + raise MlflowException( "Incompatible input types for column {0}. " - "Can not safely convert {1} to {2}.".format(name, values.dtype, numpy_type) + "Can not safely convert {1} to {2}.{3}".format(name, values.dtype, numpy_type, hint) ) @@ -399,7 +432,7 @@ class PyFuncModel(object): ``model_impl`` can be any Python object that implements the `Pyfunc interface `_, and is - by invoking the model's ``loader_module``. + returned by invoking the model's ``loader_module``. ``model_meta`` contains model metadata loaded from the MLmodel file. """ @@ -415,10 +448,16 @@ def __init__(self, model_meta: Model, model_impl: Any): def predict(self, data: pandas.DataFrame) -> PyFuncOutput: """ Generate model predictions. + + If the model contains signature, enforce the input schema first before calling the model + implementation with the sanitized input. If the pyfunc model does not include model schema, + the input is passed to the model implementation as is. See `Model Signature Enforcement + `_ for more details." + :param data: Model input as pandas.DataFrame. :return: Model predictions as one of pandas.DataFrame, pandas.Series, numpy.ndarray or list. """ - input_schema = self._model_meta.get_input_schema() + input_schema = self.metadata.get_input_schema() if input_schema is not None: data = _enforce_schema(data, input_schema) return self._model_impl.predict(data) diff --git a/mlflow/types/utils.py b/mlflow/types/utils.py index e187682dd2f6e..f91fe8e4c287f 100644 --- a/mlflow/types/utils.py +++ b/mlflow/types/utils.py @@ -1,8 +1,10 @@ from typing import Any +import warnings import numpy as np import pandas as pd + from mlflow.exceptions import MlflowException from mlflow.types import DataType from mlflow.types.schema import Schema, ColSpec @@ -54,11 +56,11 @@ def _infer_schema(data: Any) -> Schema: "Data in the dictionary must be 1-dimensional, " "got shape {}".format(ary.shape) ) - return Schema(res) + schema = Schema(res) elif isinstance(data, pd.Series): - return Schema([ColSpec(type=_infer_numpy_array(data.values))]) + schema = Schema([ColSpec(type=_infer_numpy_array(data.values))]) elif isinstance(data, pd.DataFrame): - return Schema( + schema = Schema( [ColSpec(type=_infer_numpy_array(data[col].values), name=col) for col in data.columns] ) elif isinstance(data, np.ndarray): @@ -68,25 +70,40 @@ def _infer_schema(data: Any) -> Schema: ) if data.dtype == np.object: data = pd.DataFrame(data).infer_objects() - return Schema( + schema = Schema( [ColSpec(type=_infer_numpy_array(data[col].values)) for col in data.columns] ) - if len(data.shape) == 1: - return Schema([ColSpec(type=_infer_numpy_dtype(data.dtype))]) + elif len(data.shape) == 1: + schema = Schema([ColSpec(type=_infer_numpy_dtype(data.dtype))]) elif len(data.shape) == 2: - return Schema([ColSpec(type=_infer_numpy_dtype(data.dtype))] * data.shape[1]) + schema = Schema([ColSpec(type=_infer_numpy_dtype(data.dtype))] * data.shape[1]) elif _is_spark_df(data): - return Schema( + schema = Schema( [ ColSpec(type=_infer_spark_type(field.dataType), name=field.name) for field in data.schema.fields ] ) - raise TypeError( - "Expected one of (pandas.DataFrame, numpy array, " - "dictionary of (name -> numpy.ndarray), pyspark.sql.DataFrame) " - "but got '{}'".format(type(data)) - ) + else: + raise TypeError( + "Expected one of (pandas.DataFrame, numpy array, " + "dictionary of (name -> numpy.ndarray), pyspark.sql.DataFrame) " + "but got '{}'".format(type(data)) + ) + if any([t in (DataType.integer, DataType.long) for t in schema.column_types()]): + warnings.warn( + "Hint: Inferred schema contains integer column(s). Integer columns in " + "Python can not represent missing values. If your input data contains" + "missing values at inference time, it will be encoded as floats and will " + "cause a schema enforcement error. The best way to avoid this problem is " + "to infer the model schema based on a realistic data sample (training " + "dataset) that includes missing values. Alternatively, you can declare " + "integer columns as doubles (float64) whenever these columns may have " + "missing values. See `Handling Integers With Missing Values " + "`_ for more details." + ) + return schema def _infer_numpy_dtype(dtype: np.dtype) -> DataType: diff --git a/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py b/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py index d61f49283a185..8d4e3576a6797 100644 --- a/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py +++ b/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py @@ -182,39 +182,75 @@ def predict(pdf): assert res.dtypes.to_dict() == expected_types pdf["b"] = pdf["b"].astype(np.int64) - # 3. double -> float raises + # 3. unsigned int -> long works + pdf["b"] = pdf["b"].astype(np.uint32) + res = pyfunc_model.predict(pdf) + assert all((res == pdf[input_schema.column_names()]).all()) + assert res.dtypes.to_dict() == expected_types + pdf["b"] = pdf["b"].astype(np.int64) + + # 4. unsigned int -> int raises + pdf["a"] = pdf["a"].astype(np.uint32) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex) + pdf["a"] = pdf["a"].astype(np.int32) + + # 5. double -> float raises pdf["c"] = pdf["c"].astype(np.float64) with pytest.raises(MlflowException) as ex: pyfunc_model.predict(pdf) assert "Incompatible input types" in str(ex) pdf["c"] = pdf["c"].astype(np.float32) - # 4. float -> double works + # 6. float -> double works, double -> float does not pdf["d"] = pdf["d"].astype(np.float32) res = pyfunc_model.predict(pdf) assert res.dtypes.to_dict() == expected_types assert "Incompatible input types" in str(ex) - pdf["d"] = pdf["d"].astype(np.int64) + pdf["d"] = pdf["d"].astype(np.float64) + pdf["c"] = pdf["c"].astype(np.float64) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex) + pdf["c"] = pdf["c"].astype(np.float32) - # 5. floats -> ints raises + # 7. int -> float raises pdf["c"] = pdf["c"].astype(np.int32) with pytest.raises(MlflowException) as ex: pyfunc_model.predict(pdf) assert "Incompatible input types" in str(ex) pdf["c"] = pdf["c"].astype(np.float32) + # 8. int -> double works + pdf["d"] = pdf["d"].astype(np.int32) + pyfunc_model.predict(pdf) + assert all((res == pdf[input_schema.column_names()]).all()) + assert res.dtypes.to_dict() == expected_types + + # 9. long -> double raises pdf["d"] = pdf["d"].astype(np.int64) with pytest.raises(MlflowException) as ex: pyfunc_model.predict(pdf) assert "Incompatible input types" in str(ex) pdf["d"] = pdf["d"].astype(np.float64) - # 6. ints -> floats raises + # 10. any float -> any int raises pdf["a"] = pdf["a"].astype(np.float32) with pytest.raises(MlflowException) as ex: pyfunc_model.predict(pdf) assert "Incompatible input types" in str(ex) + # 10. any float -> any int raises + pdf["a"] = pdf["a"].astype(np.float64) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex) pdf["a"] = pdf["a"].astype(np.int32) + pdf["b"] = pdf["b"].astype(np.float64) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex) + pdf["b"] = pdf["b"].astype(np.int64) pdf["b"] = pdf["b"].astype(np.float64) with pytest.raises(MlflowException) as ex: @@ -222,7 +258,7 @@ def predict(pdf): pdf["b"] = pdf["b"].astype(np.int64) assert "Incompatible input types" in str(ex) - # 7. objects work + # 11. objects work pdf["b"] = pdf["b"].astype(np.object) pdf["d"] = pdf["d"].astype(np.object) pdf["e"] = pdf["e"].astype(np.object) @@ -232,6 +268,34 @@ def predict(pdf): assert res.dtypes.to_dict() == expected_types +def test_missing_value_hint_is_displayed_when_it_should(): + class TestModel(object): + @staticmethod + def predict(pdf): + return pdf + + m = Model() + input_schema = Schema([ColSpec("integer", "a")]) + m.signature = ModelSignature(inputs=input_schema) + pyfunc_model = PyFuncModel(model_meta=m, model_impl=TestModel()) + pdf = pd.DataFrame(data=[[1], [None]], columns=["a"],) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + hint = "Hint: the type mismatch is likely caused by missing values." + assert "Incompatible input types" in str(ex.value.message) + assert hint in str(ex.value.message) + pdf = pd.DataFrame(data=[[1.5], [None]], columns=["a"],) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex) + assert hint not in str(ex.value.message) + pdf = pd.DataFrame(data=[[1], [2]], columns=["a"], dtype=np.float64) + with pytest.raises(MlflowException) as ex: + pyfunc_model.predict(pdf) + assert "Incompatible input types" in str(ex.value.message) + assert hint not in str(ex.value.message) + + def test_schema_enforcement_no_col_names(): class TestModel(object): @staticmethod From f2c854ec18a157c6bf1668f9432a20810da0ca2d Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 14 Dec 2020 13:11:56 +0900 Subject: [PATCH 11/22] Fix `AttributeError: 'Dataset' object has no attribute 'value'` in h5py < 3.0.0 (#3825) * Fix AttributeError: 'Dataset' object has no attribute 'value' Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * fix reimport Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * remove print Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- tests/keras/test_keras_model_export.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/keras/test_keras_model_export.py b/tests/keras/test_keras_model_export.py index 230a9c49b7b1d..a1c37f7c279ff 100644 --- a/tests/keras/test_keras_model_export.py +++ b/tests/keras/test_keras_model_export.py @@ -183,7 +183,12 @@ class FakeKerasModule(object): @staticmethod def load_model(file, **kwargs): # pylint: disable=unused-argument - return MyModel(file.get("x").value) + + # `Dataset.value` was removed in `h5py == 3.0.0` + if LooseVersion(h5py.__version__) >= LooseVersion("3.0.0"): + return MyModel(file.get("x")[()].decode("utf-8")) + else: + return MyModel(file.get("x").value) original_import = importlib.import_module From aec3be2c31081273d5ff4ba697a99018864d2a3e Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 14 Dec 2020 16:13:37 +0900 Subject: [PATCH 12/22] Add gluon to the cross version tests (#3826) * Add gluon to cross-version-tests Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * fix version Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Fix metric import Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * newline Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * fix typo & pylint error Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * use load_parameters Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Fix import Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Fix test_gluon_model_export.py Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Add # pylint: disable=import-error Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Fix import position Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * nit Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- ml-package-versions.yml | 18 ++++++++++++++++++ mlflow/gluon.py | 7 ++++++- tests/gluon/test_gluon_model_export.py | 6 +++++- tests/gluon_autolog/test_gluon_autolog.py | 6 +++++- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/ml-package-versions.yml b/ml-package-versions.yml index c1910a8696e77..eddd3d9901a2f 100644 --- a/ml-package-versions.yml +++ b/ml-package-versions.yml @@ -152,3 +152,21 @@ lightgbm: requirements: ["scikit-learn", "matplotlib"] run: | pytest tests/lightgbm/test_lightgbm_autolog.py --large + +gluon: + package_info: + pip_release: "mxnet" + install_dev: | + pip install --pre mxnet -f https://dist.mxnet.io/python/cpu + + models: + minimum: "1.5.1" + maximum: "1.7.0.post1" + run: | + pytest tests/gluon/test_gluon_model_export.py --large + + autologging: + minimum: "1.5.1" + maximum: "1.7.0.post1" + run: | + pytest tests/gluon_autolog/test_gluon_autolog.py --large diff --git a/mlflow/gluon.py b/mlflow/gluon.py index 3a2ee4fac83fa..11bcdfbf87d4c 100644 --- a/mlflow/gluon.py +++ b/mlflow/gluon.py @@ -1,3 +1,4 @@ +from distutils.version import LooseVersion import os import pandas as pd @@ -48,6 +49,7 @@ def load_model(model_uri, ctx): model = mlflow.gluon.load_model("runs:/" + gluon_random_data_run.info.run_id + "/model") model(nd.array(np.random.rand(1000, 1, 32))) """ + import mxnet from mxnet import gluon from mxnet import sym @@ -58,7 +60,10 @@ def load_model(model_uri, ctx): symbol = sym.load(model_arch_path) inputs = sym.var("data", dtype="float32") net = gluon.SymbolBlock(symbol, inputs) - net.collect_params().load(model_params_path, ctx) + if LooseVersion(mxnet.__version__) >= LooseVersion("2.0.0"): + net.load_parameters(model_params_path, ctx) + else: + net.collect_params().load(model_params_path, ctx) return net diff --git a/tests/gluon/test_gluon_model_export.py b/tests/gluon/test_gluon_model_export.py index 214662c320a74..b56a381044e44 100644 --- a/tests/gluon/test_gluon_model_export.py +++ b/tests/gluon/test_gluon_model_export.py @@ -14,7 +14,6 @@ from mxnet.gluon.data import DataLoader from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from mxnet.gluon.nn import HybridSequential, Dense -from mxnet.metric import Accuracy import mlflow import mlflow.gluon @@ -29,6 +28,11 @@ from tests.helper_functions import pyfunc_serve_and_score_model +if LooseVersion(mx.__version__) >= LooseVersion("2.0.0"): + from mxnet.gluon.metric import Accuracy # pylint: disable=import-error +else: + from mxnet.metric import Accuracy # pylint: disable=import-error + @pytest.fixture def model_path(tmpdir): diff --git a/tests/gluon_autolog/test_gluon_autolog.py b/tests/gluon_autolog/test_gluon_autolog.py index 70730c28e6cf9..ee6b463cad3cf 100644 --- a/tests/gluon_autolog/test_gluon_autolog.py +++ b/tests/gluon_autolog/test_gluon_autolog.py @@ -11,13 +11,17 @@ from mxnet.gluon.data import Dataset, DataLoader from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from mxnet.gluon.nn import HybridSequential, Dense -from mxnet.metric import Accuracy import mlflow import mlflow.gluon from mlflow.utils.autologging_utils import BatchMetricsLogger from unittest.mock import patch +if LooseVersion(mx.__version__) >= LooseVersion("2.0.0"): + from mxnet.gluon.metric import Accuracy # pylint: disable=import-error +else: + from mxnet.metric import Accuracy # pylint: disable=import-error + class LogsDataset(Dataset): def __init__(self): From 26fed59b20b5991c52d04cb6241e9d060bfabffd Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 14 Dec 2020 17:23:28 +0900 Subject: [PATCH 13/22] Fix invalid metric error in statsmodels tests (#3828) * Fix invalid metric issue in statsmodels flavor Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Introduce _is_numeric Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/statsmodels.py | 3 ++- mlflow/utils/validation.py | 11 ++++++++++- tests/utils/test_validation.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mlflow/statsmodels.py b/mlflow/statsmodels.py index 44ee62cde4074..04444c7d5aed7 100644 --- a/mlflow/statsmodels.py +++ b/mlflow/statsmodels.py @@ -30,6 +30,7 @@ from mlflow.exceptions import MlflowException from mlflow.utils.annotations import experimental from mlflow.utils.autologging_utils import try_mlflow_log, log_fn_args_as_params +from mlflow.utils.validation import _is_numeric import itertools import inspect @@ -464,7 +465,7 @@ def results_to_dict(results): renamed_keys_dict = prepend_to_keys(d, f) results_dict.update(renamed_keys_dict) - elif isinstance(field, (int, float)): + elif _is_numeric(field): results_dict[f] = field except AttributeError: diff --git a/mlflow/utils/validation.py b/mlflow/utils/validation.py index 90dcfcac702d7..c7d6659bd5766 100644 --- a/mlflow/utils/validation.py +++ b/mlflow/utils/validation.py @@ -63,6 +63,15 @@ def _validate_metric_name(name): ) +def _is_numeric(value): + """ + Returns True if the passed-in value is numeric. + """ + # Note that `isinstance(bool_value, numbers.Number)` returns `True` because `bool` is a + # subclass of `int`. + return not isinstance(value, bool) and isinstance(value, numbers.Number) + + def _validate_metric(key, value, timestamp, step): """ Check that a param with the specified key, value, timestamp is valid and raise an exception if @@ -71,7 +80,7 @@ def _validate_metric(key, value, timestamp, step): _validate_metric_name(key) # value must be a Number # since bool is an instance of Number check for bool additionally - if isinstance(value, bool) or not isinstance(value, numbers.Number): + if not _is_numeric(value): raise MlflowException( "Got invalid value %s for metric '%s' (timestamp=%s). Please specify value as a valid " "double (64-bit floating point)" % (value, key, timestamp), diff --git a/tests/utils/test_validation.py b/tests/utils/test_validation.py index e79926bc7a441..f35d6734f3245 100644 --- a/tests/utils/test_validation.py +++ b/tests/utils/test_validation.py @@ -5,6 +5,7 @@ from mlflow.entities import Metric, Param, RunTag from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE from mlflow.utils.validation import ( + _is_numeric, _validate_metric_name, _validate_param_name, _validate_tag_name, @@ -43,6 +44,15 @@ ] +def test_is_numeric(): + assert _is_numeric(0) + assert _is_numeric(0.0) + assert not _is_numeric(True) + assert not _is_numeric(False) + assert not _is_numeric("0") + assert not _is_numeric(None) + + def test_validate_metric_name(): for good_name in GOOD_METRIC_OR_PARAM_NAMES: _validate_metric_name(good_name) From 0be73523cfb9ea5ef47380698307102fc1397752 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 14 Dec 2020 18:25:22 +0900 Subject: [PATCH 14/22] Add fastai to the cross version tests (#3830) * Add fastai to the cross version tests Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * add sklearn Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- ml-package-versions.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ml-package-versions.yml b/ml-package-versions.yml index eddd3d9901a2f..11a262166a5ca 100644 --- a/ml-package-versions.yml +++ b/ml-package-versions.yml @@ -170,3 +170,21 @@ gluon: maximum: "1.7.0.post1" run: | pytest tests/gluon_autolog/test_gluon_autolog.py --large + +fastai-1.x: + package_info: + pip_release: "fastai" + + models: + minimum: "1.0.60" + maximum: "1.0.61" + requirements: ["scikit-learn"] + run: | + pytest tests/fastai/test_fastai_model_export.py --large + + autologging: + minimum: "1.0.60" + maximum: "1.0.61" + requirements: ["scikit-learn"] + run: | + pytest tests/fastai/test_fastai_autolog.py --large From 24211dfaf5e6b5de236e45cd734eea2f6a6d0737 Mon Sep 17 00:00:00 2001 From: dbczumar <39497902+dbczumar@users.noreply.github.com> Date: Mon, 14 Dec 2020 01:35:04 -0800 Subject: [PATCH 15/22] Add autologging safety utils to several autologging integrations (#3815) * Safe Signed-off-by: Corey Zumar * Keras Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * TF Signed-off-by: Corey Zumar * Fixes Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Some unit tests Signed-off-by: Corey Zumar * More unit tests Signed-off-by: Corey Zumar * Test coverage for safe_patch Signed-off-by: Corey Zumar * Add public API for autologging integration configs Signed-off-by: Mohamad Arabi Signed-off-by: Corey Zumar * Remove big comment Signed-off-by: Corey Zumar * Conf tests Signed-off-by: Corey Zumar * Tests Signed-off-by: Corey Zumar * Mark large Signed-off-by: Corey Zumar * Whitespace Signed-off-by: Corey Zumar * Blackspace Signed-off-by: Corey Zumar * Rename Signed-off-by: Corey Zumar * Simplify, will raise integrations as separate PR Signed-off-by: Corey Zumar * Remove partial tensorflow Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Updates from utils Signed-off-by: Corey Zumar * Remove test_mode_off for now Signed-off-by: Corey Zumar * Support positional arguments Signed-off-by: Corey Zumar * Docstring fix Signed-off-by: Corey Zumar * use match instead of comparison to str(exc) Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Forward args Signed-off-by: Corey Zumar * Fixes from #3682 Signed-off-by: Corey Zumar * integration start Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Try importing mock from unittest? Signed-off-by: Corey Zumar * Fix import mock in statsmodel Signed-off-by: Corey Zumar * Mock fix Signed-off-by: Corey Zumar * Revert "Fix import mock in statsmodel" This reverts commit a81e8101a3a5da1962bfee22dee1006adf1fd728. Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Support tuple Signed-off-by: Corey Zumar * Address more comments Signed-off-by: Corey Zumar * Stop patching log_param Signed-off-by: Corey Zumar * Modules Signed-off-by: Corey Zumar * Another test, enable test mode broadly Signed-off-by: Corey Zumar * Black Signed-off-by: Corey Zumar * Fix Signed-off-by: Corey Zumar * Move to fixture Signed-off-by: Corey Zumar * Docstring Signed-off-by: Corey Zumar * Use test mode for try_mlflow_log Signed-off-by: Corey Zumar * Test try_mlflow_log Signed-off-by: Corey Zumar * Docs Signed-off-by: Corey Zumar * Assert Signed-off-by: Corey Zumar * Try log keras Signed-off-by: Corey Zumar * Review comment, add init for tests Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Actually commit the fixtures file... Signed-off-by: Corey Zumar * Test fixes, lint Signed-off-by: Corey Zumar * Fix, format Signed-off-by: Corey Zumar * Fix fast.ai Signed-off-by: Corey Zumar * Lintfix Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar * Docstrings Signed-off-by: Corey Zumar * Address nit Signed-off-by: Corey Zumar * Lint Signed-off-by: Corey Zumar Co-authored-by: Mohamad Arabi --- mlflow/keras.py | 59 ++++++++--------- mlflow/sklearn/__init__.py | 55 +++++----------- mlflow/tracking/fluent.py | 6 +- mlflow/utils/autologging_utils.py | 8 ++- mlflow/xgboost.py | 38 +++++------ tests/autologging/__init__.py | 0 tests/autologging/fixtures.py | 20 ++++++ .../test_autologging_safety_integration.py | 65 +++++++++++++++++++ .../test_autologging_safety_unit.py | 43 ++++++++++-- tests/autologging/test_autologging_utils.py | 4 ++ tests/conftest.py | 23 +++++++ tests/fastai/test_fastai_autolog.py | 16 ++--- tests/sklearn/test_sklearn_autolog.py | 35 ++-------- 13 files changed, 231 insertions(+), 141 deletions(-) create mode 100644 tests/autologging/__init__.py create mode 100644 tests/autologging/fixtures.py create mode 100644 tests/autologging/test_autologging_safety_integration.py diff --git a/mlflow/keras.py b/mlflow/keras.py index 3b65969aa8452..9be7217f03889 100644 --- a/mlflow/keras.py +++ b/mlflow/keras.py @@ -24,14 +24,15 @@ from mlflow.models.signature import ModelSignature from mlflow.models.utils import ModelInputExample, _save_example from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.utils import gorilla from mlflow.utils.environment import _mlflow_conda_env from mlflow.utils.model_utils import _get_flavor_configuration from mlflow.utils.annotations import experimental from mlflow.utils.autologging_utils import ( + autologging_integration, + safe_patch, + ExceptionSafeClass, try_mlflow_log, log_fn_args_as_params, - wrap_patch, batch_metrics_logger, ) from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS @@ -561,10 +562,12 @@ def load_model(model_uri, **kwargs): @experimental -def autolog(log_models=True): +@autologging_integration(FLAVOR_NAME) +def autolog(log_models=True, disable=False): # pylint: disable=unused-argument # pylint: disable=E0611 """ - Enables automatic logging from Keras to MLflow. Autologging captures the following information: + Enables (or disables) and configures autologging from Keras to MLflow. Autologging captures + the following information: **Metrics** and **Parameters** - Training loss; validation loss; user-specified metrics @@ -611,11 +614,13 @@ def autolog(log_models=True): :param log_models: If ``True``, trained models are logged as MLflow model artifacts. If ``False``, trained models are not logged. + :param disable: If ``True``, disables all supported autologging integrations. If ``False``, + enables all supported autologging integrations. """ import keras def getKerasCallback(metrics_logger): - class __MLflowKerasCallback(keras.callbacks.Callback): + class __MLflowKerasCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass): """ Callback for auto-logging metrics and parameters. Records available logs after each epoch. @@ -691,17 +696,14 @@ def _early_stop_check(callbacks): def _log_early_stop_callback_params(callback): if callback: - try: - earlystopping_params = { - "monitor": callback.monitor, - "min_delta": callback.min_delta, - "patience": callback.patience, - "baseline": callback.baseline, - "restore_best_weights": callback.restore_best_weights, - } - try_mlflow_log(mlflow.log_params, earlystopping_params) - except Exception: # pylint: disable=W0703 - return + earlystopping_params = { + "monitor": callback.monitor, + "min_delta": callback.min_delta, + "patience": callback.patience, + "baseline": callback.baseline, + "restore_best_weights": callback.restore_best_weights, + } + try_mlflow_log(mlflow.log_params, earlystopping_params) def _get_early_stop_callback_attrs(callback): try: @@ -731,12 +733,6 @@ def _log_early_stop_callback_metrics(callback, history, metrics_logger): metrics_logger.record_metrics(restored_metrics, last_epoch) def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index): - if not mlflow.active_run(): - try_mlflow_log(mlflow.start_run) - auto_end_run = True - else: - auto_end_run = False - log_fn_args_as_params(original, args, kwargs, unlogged_params) early_stop_callback = None @@ -755,37 +751,34 @@ def _run_and_log_function(self, original, args, kwargs, unlogged_params, callbac else: kwargs["callbacks"] = [mlflowKerasCallback] - _log_early_stop_callback_params(early_stop_callback) + try_mlflow_log(_log_early_stop_callback_params, early_stop_callback) history = original(self, *args, **kwargs) - _log_early_stop_callback_metrics(early_stop_callback, history, metrics_logger) - - if auto_end_run: - try_mlflow_log(mlflow.end_run) + try_mlflow_log( + _log_early_stop_callback_metrics, early_stop_callback, history, metrics_logger + ) return history - def fit(self, *args, **kwargs): - original = gorilla.get_original_attribute(keras.Model, "fit") + def fit(original, self, *args, **kwargs): unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"] return _run_and_log_function(self, original, args, kwargs, unlogged_params, 5) - def fit_generator(self, *args, **kwargs): + def fit_generator(original, self, *args, **kwargs): """ NOTE: `fit_generator()` is deprecated in Keras >= 2.4.0 and simply wraps `fit()`. To avoid unintentional creation of nested MLflow runs caused by a patched `fit_generator()` method calling a patched `fit()` method, we only patch `fit_generator()` in Keras < 2.4.0. """ - original = gorilla.get_original_attribute(keras.Model, "fit_generator") unlogged_params = ["self", "generator", "callbacks", "validation_data", "verbose"] return _run_and_log_function(self, original, args, kwargs, unlogged_params, 4) - wrap_patch(keras.Model, "fit", fit) + safe_patch(FLAVOR_NAME, keras.Model, "fit", fit, manage_run=True) # `fit_generator()` is deprecated in Keras >= 2.4.0 and simply wraps `fit()`. # To avoid unintentional creation of nested MLflow runs caused by a patched # `fit_generator()` method calling a patched `fit()` method, we only patch # `fit_generator()` in Keras < 2.4.0. if LooseVersion(keras.__version__) < LooseVersion("2.4.0"): - wrap_patch(keras.Model, "fit_generator", fit_generator) + safe_patch(FLAVOR_NAME, keras.Model, "fit_generator", fit_generator, manage_run=True) diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py index 5a8e093c4bd65..15d66f3f3f45d 100644 --- a/mlflow/sklearn/__init__.py +++ b/mlflow/sklearn/__init__.py @@ -18,7 +18,6 @@ import mlflow from mlflow import pyfunc -from mlflow.entities.run_status import RunStatus from mlflow.exceptions import MlflowException from mlflow.models import Model from mlflow.models.model import MLMODEL_FILE_NAME @@ -27,13 +26,13 @@ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, INTERNAL_ERROR from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.utils import gorilla from mlflow.utils.annotations import experimental from mlflow.utils.environment import _mlflow_conda_env from mlflow.utils.model_utils import _get_flavor_configuration from mlflow.utils.autologging_utils import ( + autologging_integration, + safe_patch, try_mlflow_log, - wrap_patch, INPUT_EXAMPLE_SAMPLE_ROWS, resolve_input_example_and_signature, ) @@ -532,9 +531,12 @@ def should_log(self): @experimental -def autolog(log_input_examples=False, log_model_signatures=True, log_models=True): +@autologging_integration(FLAVOR_NAME) +def autolog( + log_input_examples=False, log_model_signatures=True, log_models=True, disable=False +): # pylint: disable=unused-argument """ - Enables autologging for scikit-learn estimators. + Enables (or disables) and configures autologging for scikit-learn estimators. **When is autologging performed?** Autologging is performed when you call: @@ -712,6 +714,8 @@ def fetch_logged_data(run_id): If ``False``, trained models are not logged. Input examples and model signatures, which are attributes of MLflow models, are also omitted when ``log_models`` is ``False``. + :param disable: If ``True``, disables all supported autologging integrations. If ``False``, + enables all supported autologging integrations. """ import pandas as pd import sklearn @@ -749,32 +753,15 @@ def fetch_logged_data(run_id): stacklevel=2, ) - def fit_mlflow(self, clazz, func_name, *args, **kwargs): + def fit_mlflow(original, self, *args, **kwargs): """ Autologging function that performs model training by executing the training method referred to be `func_name` on the instance of `clazz` referred to by `self` & records MLflow parameters, metrics, tags, and artifacts to a corresponding MLflow Run. """ - should_start_run = mlflow.active_run() is None - if should_start_run: - try_mlflow_log(mlflow.start_run) - _log_pretraining_metadata(self, *args, **kwargs) - - original_fit = gorilla.get_original_attribute(clazz, func_name) - try: - fit_output = original_fit(self, *args, **kwargs) - except Exception as e: - if should_start_run: - try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED)) - - raise e - + fit_output = original(self, *args, **kwargs) _log_posttraining_metadata(self, *args, **kwargs) - - if should_start_run: - try_mlflow_log(mlflow.end_run) - return fit_output def _log_pretraining_metadata(estimator, *args, **kwargs): # pylint: disable=unused-argument @@ -923,7 +910,7 @@ def infer_model_signature(input_example): ) _logger.warning(msg) - def patched_fit(self, clazz, func_name, *args, **kwargs): + def patched_fit(original, self, *args, **kwargs): """ Autologging patch function to be applied to a sklearn model class that defines a `fit` method and inherits from `BaseEstimator` (thereby defining the `get_params()` method) @@ -934,18 +921,11 @@ def patched_fit(self, clazz, func_name, *args, **kwargs): for autologging (e.g., specify "fit" in order to indicate that `sklearn.linear_model.LogisticRegression.fit()` is being patched) """ - with _SklearnTrainingSession(clazz=clazz, allow_children=False) as t: + with _SklearnTrainingSession(clazz=self.__class__, allow_children=False) as t: if t.should_log(): - return fit_mlflow(self, clazz, func_name, *args, **kwargs) + return fit_mlflow(original, self, *args, **kwargs) else: - original_fit = gorilla.get_original_attribute(clazz, func_name) - return original_fit(self, *args, **kwargs) - - def create_patch_func(clazz, func_name): - def f(self, *args, **kwargs): - return patched_fit(self, clazz, func_name, *args, **kwargs) - - return f + return original(self, *args, **kwargs) _, estimators_to_patch = zip(*_all_estimators()) # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected @@ -998,5 +978,6 @@ def f(self, *args, **kwargs): if isinstance(original, property): continue - patch_func = create_patch_func(class_def, func_name) - wrap_patch(class_def, func_name, patch_func) + safe_patch( + FLAVOR_NAME, class_def, func_name, patched_fit, manage_run=True, + ) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index e652ba0868056..e5a729d76f337 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -1212,10 +1212,10 @@ def _get_experiment_id(): def autolog( - log_input_examples=False, log_model_signatures=True, log_models=True + log_input_examples=False, log_model_signatures=True, log_models=True, disable=False ): # pylint: disable=unused-argument """ - Enable autologging for all supported integrations. + Enables (or disables) and configures autologging for all supported integrations. The parameters are passed to any autologging integrations that support them. @@ -1237,6 +1237,8 @@ def autolog( If ``False``, trained models are not logged. Input examples and model signatures, which are attributes of MLflow models, are also omitted when ``log_models`` is ``False``. + :param disable: If ``True``, disables all supported autologging integrations. If ``False``, + enables all supported autologging integrations. .. code-block:: python :caption: Example diff --git a/mlflow/utils/autologging_utils.py b/mlflow/utils/autologging_utils.py index c775dbaa61ff6..2328d369401cb 100644 --- a/mlflow/utils/autologging_utils.py +++ b/mlflow/utils/autologging_utils.py @@ -19,6 +19,7 @@ ENSURE_AUTOLOGGING_ENABLED_TEXT = ( "please ensure that autologging is enabled before constructing the dataset." ) +_AUTOLOGGING_TEST_MODE_ENV_VAR = "MLFLOW_AUTOLOGGING_TESTING" # Dict mapping integration name to its config. AUTOLOGGING_INTEGRATIONS = {} @@ -33,7 +34,10 @@ def try_mlflow_log(fn, *args, **kwargs): try: return fn(*args, **kwargs) except Exception as e: # pylint: disable=broad-except - warnings.warn("Logging to MLflow failed: " + str(e), stacklevel=2) + if _is_testing(): + raise + else: + warnings.warn("Logging to MLflow failed: " + str(e), stacklevel=2) def log_fn_args_as_params(fn, args, kwargs, unlogged=[]): # pylint: disable=W0102 @@ -373,7 +377,7 @@ def _is_testing(): """ import os - return os.environ.get("MLFLOW_AUTOLOGGING_TESTING", "false") == "true" + return os.environ.get(_AUTOLOGGING_TEST_MODE_ENV_VAR, "false") == "true" # Function attribute used for testing purposes to verify that a given function diff --git a/mlflow/xgboost.py b/mlflow/xgboost.py index c688e37c02bb8..22552992f9b87 100644 --- a/mlflow/xgboost.py +++ b/mlflow/xgboost.py @@ -32,15 +32,16 @@ from mlflow.models.signature import ModelSignature from mlflow.models.utils import _save_example from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.utils import gorilla from mlflow.utils.environment import _mlflow_conda_env from mlflow.utils.model_utils import _get_flavor_configuration from mlflow.exceptions import MlflowException from mlflow.utils.annotations import experimental from mlflow.utils.autologging_utils import ( + autologging_integration, + safe_patch, + exception_safe_function, try_mlflow_log, log_fn_args_as_params, - wrap_patch, INPUT_EXAMPLE_SAMPLE_ROWS, resolve_input_example_and_signature, _InputExampleInfo, @@ -283,11 +284,16 @@ def predict(self, dataframe): @experimental +@autologging_integration(FLAVOR_NAME) def autolog( - importance_types=None, log_input_examples=False, log_model_signatures=True, log_models=True, -): + importance_types=None, + log_input_examples=False, + log_model_signatures=True, + log_models=True, + disable=False, +): # pylint: disable=W0102,unused-argument """ - Enables automatic logging from XGBoost to MLflow. Logs the following. + Enables (or disables) and configures autologging from XGBoost to MLflow. Logs the following: - parameters specified in `xgboost.train`_. - metrics on each iteration (if ``evals`` specified). @@ -316,6 +322,8 @@ def autolog( If ``False``, trained models are not logged. Input examples and model signatures, which are attributes of MLflow models, are also omitted when ``log_models`` is ``False``. + :param disable: If ``True``, disables all supported autologging integrations. If ``False``, + enables all supported autologging integrations. """ import xgboost import numpy as np @@ -327,9 +335,8 @@ def autolog( # to use as an input example and for inferring the model signature. # (there is no way to get the data back from a DMatrix object) # We store it on the DMatrix object so the train function is able to read it. - def __init__(self, *args, **kwargs): + def __init__(original, self, *args, **kwargs): data = args[0] if len(args) > 0 else kwargs.get("data") - original = gorilla.get_original_attribute(xgboost.DMatrix, "__init__") if data is not None: try: @@ -348,24 +355,19 @@ def __init__(self, *args, **kwargs): original(self, *args, **kwargs) - def train(*args, **kwargs): + def train(original, *args, **kwargs): def record_eval_results(eval_results, metrics_logger): """ Create a callback function that records evaluation results. """ + @exception_safe_function def callback(env): metrics_logger.record_metrics(dict(env.evaluation_result_list), env.iteration) eval_results.append(dict(env.evaluation_result_list)) return callback - if not mlflow.active_run(): - try_mlflow_log(mlflow.start_run) - auto_end_run = True - else: - auto_end_run = False - def log_feature_importance_plot(features, importance, importance_type): """ Log feature importance plot. @@ -403,8 +405,6 @@ def log_feature_importance_plot(features, importance, importance_type): plt.close(fig) shutil.rmtree(tmpdir) - original = gorilla.get_original_attribute(xgboost, "train") - # logging booster params separately via mlflow.log_params to extract key/value pairs # and make it easier to compare them across runs. params = args[0] if len(args) > 0 else kwargs["params"] @@ -518,9 +518,7 @@ def infer_model_signature(input_example): input_example=input_example, ) - if auto_end_run: - try_mlflow_log(mlflow.end_run) return model - wrap_patch(xgboost, "train", train) - wrap_patch(xgboost.DMatrix, "__init__", __init__) + safe_patch(FLAVOR_NAME, xgboost, "train", train, manage_run=True) + safe_patch(FLAVOR_NAME, xgboost.DMatrix, "__init__", __init__) diff --git a/tests/autologging/__init__.py b/tests/autologging/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/autologging/fixtures.py b/tests/autologging/fixtures.py new file mode 100644 index 0000000000000..e42c7e98f2de7 --- /dev/null +++ b/tests/autologging/fixtures.py @@ -0,0 +1,20 @@ +import pytest +from unittest import mock + +import mlflow.utils.autologging_utils as autologging_utils + + +@pytest.fixture +def test_mode_off(): + with mock.patch("mlflow.utils.autologging_utils._is_testing") as testing_mock: + testing_mock.return_value = False + assert not autologging_utils._is_testing() + yield + + +@pytest.fixture +def test_mode_on(): + with mock.patch("mlflow.utils.autologging_utils._is_testing") as testing_mock: + testing_mock.return_value = True + assert autologging_utils._is_testing() + yield diff --git a/tests/autologging/test_autologging_safety_integration.py b/tests/autologging/test_autologging_safety_integration.py new file mode 100644 index 0000000000000..2a763e615b523 --- /dev/null +++ b/tests/autologging/test_autologging_safety_integration.py @@ -0,0 +1,65 @@ +# pylint: disable=unused-argument + +import importlib +import pytest +from unittest import mock + +import mlflow +from mlflow.utils import gorilla +from mlflow.utils.autologging_utils import ( + safe_patch, + get_autologging_config, + autologging_is_disabled, +) + + +pytestmark = pytest.mark.large + + +AUTOLOGGING_INTEGRATIONS_TO_TEST = { + mlflow.sklearn: "sklearn", + mlflow.keras: "keras", + mlflow.xgboost: "xgboost", +} + + +@pytest.fixture(autouse=True, scope="module") +def import_integration_libraries(): + for library_module in AUTOLOGGING_INTEGRATIONS_TO_TEST.values(): + importlib.import_module(library_module) + + +@pytest.fixture(autouse=True) +def disable_autologging_at_test_end(): + for integration in AUTOLOGGING_INTEGRATIONS_TO_TEST: + integration.autolog(disable=True) + + +def test_autologging_integrations_expose_configs_and_support_disablement(): + for integration in AUTOLOGGING_INTEGRATIONS_TO_TEST: + integration.autolog(disable=False) + + assert not autologging_is_disabled(integration.FLAVOR_NAME) + assert not get_autologging_config(integration.FLAVOR_NAME, "disable", True) + + integration.autolog(disable=True) + + assert autologging_is_disabled(integration.FLAVOR_NAME) + assert get_autologging_config(integration.FLAVOR_NAME, "disable", False) + + +def test_autologging_integrations_use_safe_patch_for_monkey_patching(): + for integration in AUTOLOGGING_INTEGRATIONS_TO_TEST: + with mock.patch( + "mlflow.utils.gorilla.apply", wraps=gorilla.apply + ) as gorilla_mock, mock.patch( + integration.__name__ + ".safe_patch", wraps=safe_patch + ) as safe_patch_mock: + integration.autolog(disable=False) + assert safe_patch_mock.call_count > 0 + # `safe_patch` leverages `gorilla.apply` in its implementation. Accordingly, we expect + # that the total number of `gorilla.apply` calls to be equivalent to the number of + # `safe_patch` calls. This verifies that autologging integrations are leveraging + # `safe_patch`, rather than calling `gorilla.apply` directly (which does not provide + # exception safety properties) + assert safe_patch_mock.call_count == gorilla_mock.call_count diff --git a/tests/autologging/test_autologging_safety_unit.py b/tests/autologging/test_autologging_safety_unit.py index fbb0e36edb1be..e51ac577836fb 100644 --- a/tests/autologging/test_autologging_safety_unit.py +++ b/tests/autologging/test_autologging_safety_unit.py @@ -19,21 +19,29 @@ with_managed_run, _validate_args, _is_testing, + try_mlflow_log, ) +from tests.autologging.fixtures import test_mode_off, test_mode_on # pylint: disable=unused-import + pytestmark = pytest.mark.large -@pytest.fixture -def test_mode_on(): - with mock.patch("mlflow.utils.autologging_utils._is_testing") as testing_mock: - testing_mock.return_value = True - assert autologging_utils._is_testing() - yield +PATCH_DESTINATION_FN_DEFAULT_RESULT = "original_result" -PATCH_DESTINATION_FN_DEFAULT_RESULT = "original_result" +@pytest.fixture(autouse=True) +def turn_test_mode_off_by_default(test_mode_off): + """ + Most of the unit test cases in this module assume that autologging APIs are operating in a + standard execution mode (i.e. where test mode is disabled). Accordingly, we turn off autologging + test mode for this test module by default. Test cases that verify behaviors specific to test + mode enable test mode explicitly by specifying the `test_mode_on` fixture. + + For more information about autologging test mode, see the docstring for + :py:func:`mlflow.utils.autologging_utils._is_testing()`. + """ @pytest.fixture @@ -833,3 +841,24 @@ def test_validate_args_throws_when_arg_types_or_values_are_changed(): _validate_args( user_call_args, user_call_kwargs, user_call_args, invalid_autologging_call_kwargs_2 ) + + +def test_try_mlflow_log_emits_exceptions_as_warnings_in_standard_mode(): + assert not autologging_utils._is_testing() + + def throwing_function(): + raise Exception("bad implementation") + + with pytest.warns(UserWarning, match="bad implementation"): + try_mlflow_log(throwing_function) + + +@pytest.mark.usefixtures(test_mode_on.__name__) +def test_try_mlflow_log_propagates_exceptions_in_test_mode(): + assert autologging_utils._is_testing() + + def throwing_function(): + raise Exception("bad implementation") + + with pytest.raises(Exception, match="bad implementation"): + try_mlflow_log(throwing_function) diff --git a/tests/autologging/test_autologging_utils.py b/tests/autologging/test_autologging_utils.py index a0182079e7a84..e80cf08228e8a 100644 --- a/tests/autologging/test_autologging_utils.py +++ b/tests/autologging/test_autologging_utils.py @@ -23,6 +23,9 @@ from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS +from tests.autologging.fixtures import test_mode_off + + pytestmark = pytest.mark.large @@ -380,6 +383,7 @@ def test_batch_metrics_logger_logs_timestamps_as_int_milliseconds(start_run,): assert logged_metric.timestamp == 123456 +@pytest.mark.usefixtures(test_mode_off.__name__) def test_batch_metrics_logger_continues_if_log_batch_fails(start_run,): with mock.patch.object(MlflowClient, "log_batch") as log_batch_mock: log_batch_mock.side_effect = [Exception("asdf"), None] diff --git a/tests/conftest.py b/tests/conftest.py index fa4407762958d..5bdfd04b32de8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,10 @@ import pytest import mlflow +from mlflow.utils.autologging_utils import ( + _is_testing, + _AUTOLOGGING_TEST_MODE_ENV_VAR, +) from mlflow.utils.file_utils import path_to_local_sqlite_uri @@ -33,3 +37,22 @@ def tracking_uri_mock(tmpdir, request): mlflow.set_tracking_uri(None) if "notrackingurimock" not in request.keywords: del os.environ["MLFLOW_TRACKING_URI"] + + +@pytest.fixture(autouse=True, scope="session") +def enable_test_mode_by_default_for_autologging_integrations(): + """ + Run all MLflow tests in autologging test mode, ensuring that errors in autologging patch code + are raised and detected. For more information about autologging test mode, see the docstring + for :py:func:`mlflow.utils.autologging_utils._is_testing()`. + """ + try: + prev_env_var_value = os.environ.pop(_AUTOLOGGING_TEST_MODE_ENV_VAR, None) + os.environ[_AUTOLOGGING_TEST_MODE_ENV_VAR] = "true" + assert _is_testing() + yield + finally: + if prev_env_var_value: + os.environ[_AUTOLOGGING_TEST_MODE_ENV_VAR] = prev_env_var_value + else: + del os.environ[_AUTOLOGGING_TEST_MODE_ENV_VAR] diff --git a/tests/fastai/test_fastai_autolog.py b/tests/fastai/test_fastai_autolog.py index 34b0d2c4bdcec..7e2dd182d0237 100644 --- a/tests/fastai/test_fastai_autolog.py +++ b/tests/fastai/test_fastai_autolog.py @@ -177,20 +177,18 @@ def test_fastai_autolog_model_can_load_from_artifact(fastai_random_data_run): def fastai_random_data_run_with_callback(iris_data, fit_variant, manual_run, callback, patience): # pylint: disable=unused-argument mlflow.fastai.autolog() - callbacks = [] - if callback == "early": - # min_delta is set as such to guarantee early stopping - callbacks.append( - lambda learn: EarlyStoppingCallback(learn, patience=patience, min_delta=MIN_DELTA) - ) + model = fastai_model(iris_data) - model = fastai_model(iris_data, callback_fns=callbacks) + callbacks = [] + if callback == "early": + callback = EarlyStoppingCallback(learn=model, patience=patience, min_delta=MIN_DELTA) + callbacks.append(callback) if fit_variant == "fit_one_cycle": - model.fit_one_cycle(NUM_EPOCHS) + model.fit_one_cycle(NUM_EPOCHS, callbacks=callbacks) else: - model.fit(NUM_EPOCHS) + model.fit(NUM_EPOCHS, callbacks=callbacks) client = mlflow.tracking.MlflowClient() return model, client.get_run(client.list_run_infos(experiment_id="0")[0].run_id) diff --git a/tests/sklearn/test_sklearn_autolog.py b/tests/sklearn/test_sklearn_autolog.py index d5d9d4a577ffa..cd54f06494bcb 100644 --- a/tests/sklearn/test_sklearn_autolog.py +++ b/tests/sklearn/test_sklearn_autolog.py @@ -2,7 +2,6 @@ import inspect from unittest import mock import os -import warnings import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -26,7 +25,6 @@ _truncate_dict, ) from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID -from mlflow.utils.autologging_utils import try_mlflow_log from mlflow.utils.validation import ( MAX_PARAMS_TAGS_PER_BATCH, MAX_METRICS_PER_BATCH, @@ -34,6 +32,8 @@ MAX_ENTITY_KEY_LENGTH, ) +from tests.autologging.fixtures import test_mode_off + FIT_FUNC_NAMES = ["fit", "fit_transform", "fit_predict"] TRAINING_SCORE = "training_score" ESTIMATOR_CLASS = "estimator_class" @@ -107,33 +107,6 @@ def fit_func_name(request): return request.param -@pytest.fixture(autouse=True, scope="function") -def force_try_mlflow_log_to_fail(request): - # autolog contains multiple `try_mlflow_log`. They unexpectedly allow tests that - # should fail to pass (without us noticing). To prevent that, temporarily turns - # warnings emitted by `try_mlflow_log` into errors. - if "disable_force_try_mlflow_log_to_fail" in request.keywords: - yield - else: - with warnings.catch_warnings(): - warnings.filterwarnings( - "error", message=r"^Logging to MLflow failed", category=UserWarning, - ) - yield - - -@pytest.mark.xfail(strict=True, raises=UserWarning) -def test_force_try_mlflow_log_to_fail(): - with mlflow.start_run(): - try_mlflow_log(lambda: 1 / 0) - - -@pytest.mark.disable_force_try_mlflow_log_to_fail -def test_no_force_try_mlflow_log_to_fail(): - with mlflow.start_run(): - try_mlflow_log(lambda: 1 / 0) - - def test_autolog_preserves_original_function_attributes(): def get_func_attrs(f): attrs = {} @@ -852,7 +825,7 @@ def test_parameter_search_handles_large_volume_of_metric_outputs(): assert len(child_run.data.metrics) >= metrics_size -@pytest.mark.disable_force_try_mlflow_log_to_fail +@pytest.mark.usefixtures(test_mode_off.__name__) @pytest.mark.parametrize( "failing_specialization", [ @@ -871,7 +844,7 @@ def test_autolog_does_not_throw_when_parameter_search_logging_fails(failing_spec mock_func.assert_called_once() -@pytest.mark.disable_force_try_mlflow_log_to_fail +@pytest.mark.usefixtures(test_mode_off.__name__) @pytest.mark.parametrize( "func_to_fail", ["mlflow.log_params", "mlflow.log_metric", "mlflow.set_tags", "mlflow.sklearn.log_model"], From 94bc305c3da7f3c8b102f14ca4f6dbdf5523a212 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 10:06:47 -0800 Subject: [PATCH 16/22] add test case for before spark session Signed-off-by: Mohamad Arabi --- .../test_spark_datasource_autologging_order.py | 15 +++++++++++---- .../test_spark_disable_autologging.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/spark_autologging/test_spark_datasource_autologging_order.py b/tests/spark_autologging/test_spark_datasource_autologging_order.py index 6f640bfb14004..5b9b23a445823 100644 --- a/tests/spark_autologging/test_spark_datasource_autologging_order.py +++ b/tests/spark_autologging/test_spark_datasource_autologging_order.py @@ -11,12 +11,16 @@ from pyspark.sql.types import StructType, IntegerType, StructField from tests.spark_autologging.utils import _get_or_create_spark_session -from tests.spark_autologging.utils import _assert_spark_data_logged +from tests.spark_autologging.utils import ( + _assert_spark_data_logged, + _assert_spark_data_not_logged, +) @pytest.mark.large -def test_enabling_autologging_before_spark_session_works(): - mlflow.spark.autolog() +@pytest.mark.parametrize("disable", [False, True]) +def test_enabling_autologging_before_spark_session_works(disable): + mlflow.spark.autolog(disable=disable) # creating spark session AFTER autolog was enabled spark_session = _get_or_create_spark_session() @@ -42,7 +46,10 @@ def test_enabling_autologging_before_spark_session_works(): time.sleep(1) run = mlflow.get_run(run_id) - _assert_spark_data_logged(run=run, path=filepath, data_format="csv") + if disable: + _assert_spark_data_not_logged(run=run) + else: + _assert_spark_data_logged(run=run, path=filepath, data_format="csv") shutil.rmtree(tempdir) spark_session.stop() diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py index 97b22eb5f31e3..ff4292e3dfa75 100644 --- a/tests/spark_autologging/test_spark_disable_autologging.py +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -19,7 +19,7 @@ # Note that the following tests run one-after-the-other and operate on the SAME spark_session -# (it is not reset between tests) +# (it is not reset between tests, except for the last test case) @pytest.mark.large From 9eac9760207114d0cbbdd028e92ccb02d56ca303 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 10:19:36 -0800 Subject: [PATCH 17/22] unnecessary change Signed-off-by: Mohamad Arabi --- mlflow/utils/autologging_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlflow/utils/autologging_utils.py b/mlflow/utils/autologging_utils.py index 5f8b33b148866..2328d369401cb 100644 --- a/mlflow/utils/autologging_utils.py +++ b/mlflow/utils/autologging_utils.py @@ -26,11 +26,6 @@ _logger = logging.getLogger(__name__) -# Dict mapping integration name to its config. -AUTOLOGGING_INTEGRATIONS = {} - -_logger = logging.getLogger(__name__) - def try_mlflow_log(fn, *args, **kwargs): """ From 0ec7400a928bf37db07b0a2959ff9ce98a652492 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 10:21:21 -0800 Subject: [PATCH 18/22] modify comment Signed-off-by: Mohamad Arabi --- tests/spark_autologging/test_spark_disable_autologging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py index ff4292e3dfa75..97b22eb5f31e3 100644 --- a/tests/spark_autologging/test_spark_disable_autologging.py +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -19,7 +19,7 @@ # Note that the following tests run one-after-the-other and operate on the SAME spark_session -# (it is not reset between tests, except for the last test case) +# (it is not reset between tests) @pytest.mark.large From 4feeb118c877e318d6613493ab0266a6b850330b Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 11:15:37 -0800 Subject: [PATCH 19/22] cannot assign FLAVOR_NAME in _spark_autolgging.py Signed-off-by: Mohamad Arabi --- mlflow/_spark_autologging.py | 5 ++--- tests/spark_autologging/test_spark_disable_autologging.py | 2 -- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mlflow/_spark_autologging.py b/mlflow/_spark_autologging.py index 477a61e3e782a..0c575d82f643c 100644 --- a/mlflow/_spark_autologging.py +++ b/mlflow/_spark_autologging.py @@ -22,7 +22,6 @@ _JAVA_PACKAGE = "org.mlflow.spark.autologging" _SPARK_TABLE_INFO_TAG_NAME = "sparkDatasourceInfo" -FLAVOR_NAME = "spark" _logger = logging.getLogger(__name__) _lock = threading.Lock() @@ -221,7 +220,7 @@ def _notify(self, path, version, data_format): Method called by Scala SparkListener to propagate datasource read events to the current Python process """ - if autologging_is_disabled(FLAVOR_NAME): + if autologging_is_disabled("spark"): return # If there's an active run, simply set the tag on it # Note that there's a TOCTOU race condition here - active_run() here can actually throw @@ -255,7 +254,7 @@ def in_context(self): def tags(self): # if autologging is disabled, then short circuit `tags()` and return empty dict. - if autologging_is_disabled(FLAVOR_NAME): + if autologging_is_disabled("spark"): return {} with _lock: global _table_infos diff --git a/tests/spark_autologging/test_spark_disable_autologging.py b/tests/spark_autologging/test_spark_disable_autologging.py index 97b22eb5f31e3..781d03a394b94 100644 --- a/tests/spark_autologging/test_spark_disable_autologging.py +++ b/tests/spark_autologging/test_spark_disable_autologging.py @@ -5,8 +5,6 @@ import mlflow import mlflow.spark -from tests.tracking.test_rest_tracking import BACKEND_URIS -from tests.tracking.test_rest_tracking import tracking_server_uri # pylint: disable=unused-import from tests.tracking.test_rest_tracking import mlflow_client # pylint: disable=unused-import from tests.spark_autologging.utils import ( _assert_spark_data_logged, From cf7a1a0bca85f201d1337eaed0c70ebf782116a7 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 13:14:49 -0800 Subject: [PATCH 20/22] address final comments Signed-off-by: Mohamad Arabi --- mlflow/_spark_autologging.py | 5 +++-- mlflow/spark.py | 5 +++-- tests/test_flavors.py | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mlflow/_spark_autologging.py b/mlflow/_spark_autologging.py index 0c575d82f643c..4b672b44bce5e 100644 --- a/mlflow/_spark_autologging.py +++ b/mlflow/_spark_autologging.py @@ -19,6 +19,7 @@ wrap_patch, autologging_is_disabled, ) +from mlflow.spark import FLAVOR_NAME _JAVA_PACKAGE = "org.mlflow.spark.autologging" _SPARK_TABLE_INFO_TAG_NAME = "sparkDatasourceInfo" @@ -220,7 +221,7 @@ def _notify(self, path, version, data_format): Method called by Scala SparkListener to propagate datasource read events to the current Python process """ - if autologging_is_disabled("spark"): + if autologging_is_disabled(FLAVOR_NAME): return # If there's an active run, simply set the tag on it # Note that there's a TOCTOU race condition here - active_run() here can actually throw @@ -254,7 +255,7 @@ def in_context(self): def tags(self): # if autologging is disabled, then short circuit `tags()` and return empty dict. - if autologging_is_disabled("spark"): + if autologging_is_disabled(FLAVOR_NAME): return {} with _lock: global _table_infos diff --git a/mlflow/spark.py b/mlflow/spark.py index bfc24d7b9f656..2d7422baea177 100644 --- a/mlflow/spark.py +++ b/mlflow/spark.py @@ -631,7 +631,7 @@ def predict(self, pandas_df): @autologging_integration(FLAVOR_NAME) def autolog(disable=False): # pylint: disable=unused-argument """ - Enables automatic logging of Spark datasource paths, versions (if applicable), and formats + Enables (or disables) and configures logging of Spark datasource paths, versions (if applicable), and formats when they are read. This method is not threadsafe and assumes a `SparkSession `_ @@ -685,7 +685,8 @@ def autolog(disable=False): # pylint: disable=unused-argument with mlflow.start_run() as active_run: pandas_df = loaded_df.toPandas() - :param disable: Whether to enable or disable autologging. + :param disable: If ``True``, disables all supported autologging integrations. If ``False``, + enables all supported autologging integrations. """ from mlflow import _spark_autologging diff --git a/tests/test_flavors.py b/tests/test_flavors.py index 363f771836710..dfb274d7d7554 100644 --- a/tests/test_flavors.py +++ b/tests/test_flavors.py @@ -23,7 +23,8 @@ def is_model_flavor(src): def iter_flavor_names(): for root, _, files in os.walk("mlflow"): for f in files: - if not f.endswith(".py"): + is_private_module = f.startswith("_") and f != "__init__.py" + if not f.endswith(".py") or is_private_module: continue path = os.path.join(root, f) src = read_file(path) From a708ada5051718eb8d194fcdda4312b4323702e6 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 15:09:25 -0800 Subject: [PATCH 21/22] fix api documentation Signed-off-by: Mohamad Arabi --- mlflow/spark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlflow/spark.py b/mlflow/spark.py index 2d7422baea177..7533b40448653 100644 --- a/mlflow/spark.py +++ b/mlflow/spark.py @@ -685,8 +685,8 @@ def autolog(disable=False): # pylint: disable=unused-argument with mlflow.start_run() as active_run: pandas_df = loaded_df.toPandas() - :param disable: If ``True``, disables all supported autologging integrations. If ``False``, - enables all supported autologging integrations. + :param disable: If ``True``, disables all supported autologging integrations. + If ``False``, enables all supported autologging integrations. """ from mlflow import _spark_autologging From bc59ab0c8a736fa8a39edf22d2144f35eeece763 Mon Sep 17 00:00:00 2001 From: Mohamad Arabi Date: Mon, 14 Dec 2020 16:05:38 -0800 Subject: [PATCH 22/22] fix api documentation II Signed-off-by: Mohamad Arabi --- mlflow/spark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlflow/spark.py b/mlflow/spark.py index 7533b40448653..d7331c014a363 100644 --- a/mlflow/spark.py +++ b/mlflow/spark.py @@ -685,8 +685,8 @@ def autolog(disable=False): # pylint: disable=unused-argument with mlflow.start_run() as active_run: pandas_df = loaded_df.toPandas() - :param disable: If ``True``, disables all supported autologging integrations. - If ``False``, enables all supported autologging integrations. + :param disable: If ``True``, disables all supported autologging integrations. + If ``False``, enables all supported autologging integrations. """ from mlflow import _spark_autologging