diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 511efa291ab62..451336a774754 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -157,7 +157,7 @@ def rename_experiment(self, experiment_id, new_name): """ self.store.rename_experiment(experiment_id, new_name) - def log_metric(self, run_id, key, value, timestamp=None): + def log_metric(self, run_id, key, value, timestamp=None, step=0): """ Log a metric against the run ID. If timestamp is not provided, uses the current timestamp. diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index ced34c3fbe87b..ec03c5ff009ba 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -60,6 +60,7 @@ class ActiveRun(Run): # pylint: disable=W0223 def __init__(self, run): Run.__init__(self, run.info, run.data) + self.step = 0 def __enter__(self): return self @@ -186,15 +187,22 @@ def set_tag(key, value): MlflowClient().set_tag(run_id, key, value) -def log_metric(key, value): +def log_metric(key, value, step=None): """ Log a metric under the current run, creating a run if necessary. :param key: Metric name (string). :param value: Metric value (float). """ - run_id = _get_or_start_run().info.run_uuid - MlflowClient().log_metric(run_id, key, value, int(time.time())) + run = _get_or_start_run() + if step is None: + step = run.step + MlflowClient().log_metric(run_id=run.info.run_uuid, + key=key, + value=value, + timestamp=int(time.time()), + step=step) + run.step = step + 1 # Or just `step`? def log_metrics(metrics):