From e63182b4effc6ab2b10e44153387246e0adc3e61 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Mon, 11 May 2020 17:01:10 -0700 Subject: [PATCH 01/28] Add protos and compilation --- generate-protos.sh | 6 +- mlflow/protos/databricks_artifacts.proto | 113 +++++++++++++++++++++++ 2 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 mlflow/protos/databricks_artifacts.proto diff --git a/generate-protos.sh b/generate-protos.sh index 36cf78f890a86..bcd286ecbf760 100755 --- a/generate-protos.sh +++ b/generate-protos.sh @@ -14,16 +14,18 @@ protoc -I="$PROTOS" \ "$PROTOS"/databricks.proto \ "$PROTOS"/service.proto \ "$PROTOS"/model_registry.proto \ + "$PROTOS"/databricks_artifacts.proto \ "$PROTOS"/scalapb/scalapb.proto OLD_SCALAPB="from scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2" NEW_SCALAPB="from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2" -sed -i'.old' -e "s/$OLD_SCALAPB/$NEW_SCALAPB/g" "$PROTOS/databricks_pb2.py" "$PROTOS/service_pb2.py" "$PROTOS/model_registry_pb2.py" +sed -i'.old' -e "s/$OLD_SCALAPB/$NEW_SCALAPB/g" "$PROTOS/databricks_pb2.py" "$PROTOS/service_pb2.py" "$PROTOS/model_registry_pb2.py" "$PROTOS/databricks_artifacts_pb2.py" OLD_DATABRICKS="import databricks_pb2 as databricks__pb2" NEW_DATABRICKS="from . 
import databricks_pb2 as databricks__pb2" -sed -i'.old' -e "s/$OLD_DATABRICKS/$NEW_DATABRICKS/g" "$PROTOS/service_pb2.py" "$PROTOS/model_registry_pb2.py" +sed -i'.old' -e "s/$OLD_DATABRICKS/$NEW_DATABRICKS/g" "$PROTOS/service_pb2.py" "$PROTOS/model_registry_pb2.py" "$PROTOS/databricks_artifacts_pb2.py" rm "$PROTOS/databricks_pb2.py.old" rm "$PROTOS/service_pb2.py.old" rm "$PROTOS/model_registry_pb2.py.old" +rm "$PROTOS/databricks_artifacts_pb2.py.old" diff --git a/mlflow/protos/databricks_artifacts.proto b/mlflow/protos/databricks_artifacts.proto new file mode 100644 index 0000000000000..6e516743b2fd6 --- /dev/null +++ b/mlflow/protos/databricks_artifacts.proto @@ -0,0 +1,113 @@ +syntax = "proto2"; + +package mlflow; + +import "scalapb/scalapb.proto"; + +import "databricks.proto"; + +option java_package = "com.databricks.api.proto.mlflow"; +option java_generate_equals_and_hash = true; +option py_generic_services = true; +option (scalapb.options) = { + flat_package: true, +}; + +service DatabricksMlflowArtifactsService { + + // Fetch credentials to read from the specified MLflow artifact location + // + // Note: Even if no artifacts exist at the specified artifact location, this API will + // still provide read credentials as long as the format of the location is valid. 
+ // Callers must subsequently check for the existence of the artifacts using the appropriate + // cloud storage APIs (as determined by the `ArtifactCredentialType` property of the response) + rpc getCredentialsForRead (GetCredentialsForRead) returns (GetCredentialsForRead.Response) { + option (rpc) = { + endpoints: [{ + method: "GET", + path: "/mlflow/artifacts/credentials-for-read" + since { major: 2, minor: 0 }, + }], + visibility: PUBLIC_UNDOCUMENTED, + }; + } + + // Fetch credentials to write to the specified MLflow artifact location + rpc getCredentialsForWrite (GetCredentialsForWrite) returns (GetCredentialsForWrite.Response) { + option (rpc) = { + endpoints: [{ + method: "GET", + path: "/mlflow/artifacts/credentials-for-write" + since { major: 2, minor: 0 }, + }], + visibility: PUBLIC_UNDOCUMENTED, + }; + } +} + +// The type of a given artifact access credential +enum ArtifactCredentialType { + + // The credential is an Azure Shared Access Signature URI. For more information, see + // https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview + AZURE_SAS_URI = 1; + + // The credential is an AWS Presigned URL. 
For more information, see + // https://docs.aws.amazon.com/AmazonS3/latest/dev/ShareObjectPreSignedURL.html + AWS_PRESIGNED_URL = 2; + +} + +message ArtifactCredentialInfo { + + // The ID of the MLflow Run containing the artifact that can be accessed + // with the credential + optional string run_id = 1; + + // The path, relative to the Run's artifact root location, of the artifact + // that can be accessed with the credential + optional string path = 2; + + // The signed URI credential that provides access to the artifact + optional string signed_uri = 3; + + // The type of the signed credential URI (e.g., an AWS presigned URL + // or an Azure Shared Access Signature URI) + optional ArtifactCredentialType type = 4; + +} + +message GetCredentialsForRead { + option (scalapb.message).extends = "com.databricks.rpc.RPC[$this.Response]"; + + // The ID of the MLflow Run for which to fetch artifact read credentials + optional string run_id = 1; + + // The artifact path, relative to the Run's artifact root location, for which to + // fetch artifact read credentials + optional string path = 2; + + message Response { + + // Credentials for reading from the specified artifact location + optional ArtifactCredentialInfo credentials = 1; + + } +} + +message GetCredentialsForWrite { + option (scalapb.message).extends = "com.databricks.rpc.RPC[$this.Response]"; + + // The ID of the MLflow Run for which to fetch artifact write credentials + optional string run_id = 1; + + // The artifact path, relative to the Run's artifact root location, for which to + // fetch artifact write credentials + optional string path = 2; + + message Response { + + // Credentials for writing to the specified artifacts location + optional ArtifactCredentialInfo credentials = 1; + + } +} From cb69e267cc5c897cbef5dc93be67a35b8f70afc9 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 18 May 2020 23:09:47 -0700 Subject: [PATCH 02/28] Adding databricks_artifact_repo to store/artifact --- 
.../artifact/databricks_artifact_repo.py | 94 +++++++++++++++++++ mlflow/store/artifact/dbfs_artifact_repo.py | 6 +- mlflow/utils/databricks_utils.py | 1 - mlflow/utils/uri.py | 6 +- 4 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 mlflow/store/artifact/databricks_artifact_repo.py diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py new file mode 100644 index 0000000000000..f94f9967caafb --- /dev/null +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -0,0 +1,94 @@ +from azure.storage.blob import BlobClient + +import os +from mlflow.exceptions import MlflowException +from mlflow.store.artifact.artifact_repo import ArtifactRepository +from mlflow.utils.string_utils import strip_suffix +from mlflow.utils.file_utils import relative_path_to_artifact_path +from mlflow.utils.rest_utils import call_endpoint, extract_api_info_for_service +from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService +from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead +from mlflow.utils.databricks_utils import get_databricks_host_creds +from mlflow.protos.service_pb2 import MlflowService, ListArtifacts + +_PATH_PREFIX = "/api/2.0" + + +class DatabricksArtifactRepository(ArtifactRepository): + """ + SOMETHING : TYPING TILL IT WORKS LOL + """ + + def __init__(self, artifact_uri): + super(DatabricksArtifactRepository, self).__init__(artifact_uri) + + def _extract_run_id(self, artifact_uri): + return artifact_uri.lstrip('/').split('/')[4] + + def _call_endpoint(self, service, api, json_body): + _METHOD_TO_INFO = extract_api_info_for_service(service, _PATH_PREFIX) + endpoint, method = _METHOD_TO_INFO[api] + response_proto = api.Response() + return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto) + + def _create_json_body(self, run_id, path=None): + path = path or '.' 
+ return { + "run_id": run_id, + "path": path + } + + def _get_azure_write_credentials(self, run_id, path=None): + return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite, + self._create_json_body(run_id, path)) + + def _get_azure_read_credentials(self, run_id, path=None): + return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, + self._create_json_body(run_id, path)) + + def _upload_file(self, local_file, artifact_path): + run_id = self._extract_run_id(self.artifact_uri) + write_credentials = self._get_azure_write_credentials(run_id, artifact_path) + signed_write_uri = write_credentials.credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + try: + with open(local_file, "rb") as data: + service.upload_blob(data, overwrite=True) + except Exception as err: + raise MlflowException(err) + + def log_artifact(self, local_file, artifact_path=None): + self._upload_file(local_file, artifact_path) + + def log_artifacts(self, local_dir, artifact_path=None): + artifact_path = artifact_path or '' + basename = os.path.basename(strip_suffix(local_dir, '/')) + for (dirpath, _, filenames) in os.walk(local_dir): + artifact_subdir = basename + if dirpath != local_dir: + rel_path = os.path.relpath(dirpath, local_dir) + rel_path = relative_path_to_artifact_path(rel_path) + artifact_subdir = os.path.join(artifact_subdir, rel_path) + for name in filenames: + local_file = os.path.join(dirpath, name) + artifact_location = os.path.join(artifact_path, artifact_subdir) + self._upload_file(local_file, artifact_location) + + def list_artifacts(self, path=None): + run_id = self._extract_run_id(self.artifact_uri) + return self._call_endpoint(MlflowService, ListArtifacts, self._create_json_body(run_id, path)) + + def _download_file(self, remote_file_path, local_path): + run_id = self._extract_run_id(self.artifact_uri) + read_credentials = self._get_azure_read_credentials(run_id, 
remote_file_path) + signed_read_uri = read_credentials.credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_read_uri, credential=None) + try: + with open(local_path, "wb") as output_file: + blob = service.download_blob() + output_file.write(blob.readall()) + except Exception as err: + raise MlflowException(err) + + def delete_artifacts(self, artifact_path=None): + raise MlflowException('Not implemented yet') diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 4b213a24f9a2e..4600a0f9741f9 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -6,12 +6,14 @@ from mlflow.exceptions import MlflowException from mlflow.store.tracking.rest_store import RestStore from mlflow.store.artifact.artifact_repo import ArtifactRepository +from mlflow.store.artifact.databricks_artifact_repo import DatabricksArtifactRepository from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository from mlflow.tracking._tracking_service import utils from mlflow.utils.file_utils import relative_path_to_artifact_path from mlflow.utils.rest_utils import http_request, http_request_safe, RESOURCE_DOES_NOT_EXIST from mlflow.utils.string_utils import strip_prefix import mlflow.utils.databricks_utils +from mlflow.utils.uri import is_artifact_acled_uri LIST_API_ENDPOINT = '/api/2.0/dbfs/list' GET_STATUS_ENDPOINT = '/api/2.0/dbfs/get-status' @@ -163,7 +165,9 @@ def dbfs_artifact_repo_factory(artifact_uri): :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS. 
""" cleaned_artifact_uri = artifact_uri.rstrip('/') - if mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ + if is_artifact_acled_uri(artifact_uri): + return DatabricksArtifactRepository(artifact_uri) + elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false" \ and not artifact_uri.startswith("dbfs:/databricks/mlflow-registry"): # If the DBFS FUSE mount is available, write artifacts directly to /dbfs/... using diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 926eaee4c1bf1..6e51e812a6102 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -6,7 +6,6 @@ from mlflow.utils.rest_utils import MlflowHostCreds from databricks_cli.configure import provider - _logger = logging.getLogger(__name__) diff --git a/mlflow/utils/uri.py b/mlflow/utils/uri.py index 5d9cfe09474be..ec7a4328c24c5 100644 --- a/mlflow/utils/uri.py +++ b/mlflow/utils/uri.py @@ -8,7 +8,7 @@ _INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \ "format specifications." 
- +_ACLED_ARTIFACT_URI = "dbfs:/databricks/mlflow-tracking/" def is_local_uri(uri): """Returns true if this is a local file path (/foo or file:/foo).""" @@ -129,3 +129,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path): # joined path suffix_path = suffix_path.lstrip(posixpath.sep) return posixpath.join(prefix_path, suffix_path) + + +def is_artifact_acled_uri(artifact_uri): + return artifact_uri.startswith(_ACLED_ARTIFACT_URI.lstrip('/')) From 3af5d3b25606e6cb05cd8d5acc70e787f712c848 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Tue, 19 May 2020 14:28:18 -0700 Subject: [PATCH 03/28] Addressing comments --- .../api/proto/mlflow/DatabricksArtifacts.java | 4535 +++++++++++++++++ mlflow/protos/databricks_artifacts_pb2.py | 342 ++ .../artifact/databricks_artifact_repo.py | 69 +- mlflow/store/artifact/dbfs_artifact_repo.py | 4 +- mlflow/utils/uri.py | 14 +- 5 files changed, 4937 insertions(+), 27 deletions(-) create mode 100644 mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java create mode 100644 mlflow/protos/databricks_artifacts_pb2.py diff --git a/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java b/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java new file mode 100644 index 0000000000000..e94afa768bd0b --- /dev/null +++ b/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java @@ -0,0 +1,4535 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: databricks_artifacts.proto + +package com.databricks.api.proto.mlflow; + +public final class DatabricksArtifacts { + private DatabricksArtifacts() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + /** + *
+   * The type of a given artifact access credential
+   * 
+ * + * Protobuf enum {@code mlflow.ArtifactCredentialType} + */ + public enum ArtifactCredentialType + implements com.google.protobuf.ProtocolMessageEnum { + /** + *
+     * The credential is an Azure Shared Access Signature URI. For more information, see
+     * https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
+     * 
+ * + * AZURE_SAS_URI = 1; + */ + AZURE_SAS_URI(1), + /** + *
+     * The credential is an AWS Presigned URL. For more information, see
+     * https://docs.aws.amazon.com/AmazonS3/latest/dev/ShareObjectPreSignedURL.html
+     * 
+ * + * AWS_PRESIGNED_URL = 2; + */ + AWS_PRESIGNED_URL(2), + ; + + /** + *
+     * The credential is an Azure Shared Access Signature URI. For more information, see
+     * https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
+     * 
+ * + * AZURE_SAS_URI = 1; + */ + public static final int AZURE_SAS_URI_VALUE = 1; + /** + *
+     * The credential is an AWS Presigned URL. For more information, see
+     * https://docs.aws.amazon.com/AmazonS3/latest/dev/ShareObjectPreSignedURL.html
+     * 
+ * + * AWS_PRESIGNED_URL = 2; + */ + public static final int AWS_PRESIGNED_URL_VALUE = 2; + + + public final int getNumber() { + return value; + } + + /** + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static ArtifactCredentialType valueOf(int value) { + return forNumber(value); + } + + public static ArtifactCredentialType forNumber(int value) { + switch (value) { + case 1: return AZURE_SAS_URI; + case 2: return AWS_PRESIGNED_URL; + default: return null; + } + } + + public static com.google.protobuf.Internal.EnumLiteMap + internalGetValueMap() { + return internalValueMap; + } + private static final com.google.protobuf.Internal.EnumLiteMap< + ArtifactCredentialType> internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public ArtifactCredentialType findValueByNumber(int number) { + return ArtifactCredentialType.forNumber(number); + } + }; + + public final com.google.protobuf.Descriptors.EnumValueDescriptor + getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } + public final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptorForType() { + return getDescriptor(); + } + public static final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.getDescriptor().getEnumTypes().get(0); + } + + private static final ArtifactCredentialType[] VALUES = values(); + + public static ArtifactCredentialType valueOf( + com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException( + "EnumValueDescriptor is not for this type."); + } + return VALUES[desc.getIndex()]; + } + + private final int value; + + private ArtifactCredentialType(int value) { + this.value = value; + } + + // @@protoc_insertion_point(enum_scope:mlflow.ArtifactCredentialType) + } + + public interface ArtifactCredentialInfoOrBuilder extends + // 
@@protoc_insertion_point(interface_extends:mlflow.ArtifactCredentialInfo) + com.google.protobuf.MessageOrBuilder { + + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + boolean hasRunId(); + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + java.lang.String getRunId(); + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + com.google.protobuf.ByteString + getRunIdBytes(); + + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + boolean hasPath(); + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + java.lang.String getPath(); + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + com.google.protobuf.ByteString + getPathBytes(); + + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + boolean hasSignedUri(); + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + java.lang.String getSignedUri(); + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + com.google.protobuf.ByteString + getSignedUriBytes(); + + /** + *
+     * The type of the signed credential URI (e.g., an AWS presigned URL
+     * or an Azure Shared Access Signature URI)
+     * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + boolean hasType(); + /** + *
+     * The type of the signed credential URI (e.g., an AWS presigned URL
+     * or an Azure Shared Access Signature URI)
+     * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType(); + } + /** + * Protobuf type {@code mlflow.ArtifactCredentialInfo} + */ + public static final class ArtifactCredentialInfo extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.ArtifactCredentialInfo) + ArtifactCredentialInfoOrBuilder { + private static final long serialVersionUID = 0L; + // Use ArtifactCredentialInfo.newBuilder() to construct. + private ArtifactCredentialInfo(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ArtifactCredentialInfo() { + runId_ = ""; + path_ = ""; + signedUri_ = ""; + type_ = 1; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private ArtifactCredentialInfo( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000001; + runId_ = bs; + break; + } + case 18: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000002; + path_ = bs; + break; + } + case 26: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000004; + signedUri_ = bs; + break; + } + case 32: { + int rawValue = input.readEnum(); + @SuppressWarnings("deprecation") + 
com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType value = com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.valueOf(rawValue); + if (value == null) { + unknownFields.mergeVarintField(4, rawValue); + } else { + bitField0_ |= 0x00000008; + type_ = rawValue; + } + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder.class); + } + + private int bitField0_; + public static final int RUN_ID_FIELD_NUMBER = 1; + private volatile java.lang.Object runId_; + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } + } + /** + *
+     * The ID of the MLflow Run containing the artifact that can be accessed
+     * with the credential
+     * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int PATH_FIELD_NUMBER = 2; + private volatile java.lang.Object path_; + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } + } + /** + *
+     * The path, relative to the Run's artifact root location, of the artifact
+     * that can be accessed with the credential
+     * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int SIGNED_URI_FIELD_NUMBER = 3; + private volatile java.lang.Object signedUri_; + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + public boolean hasSignedUri() { + return ((bitField0_ & 0x00000004) == 0x00000004); + } + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + public java.lang.String getSignedUri() { + java.lang.Object ref = signedUri_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + signedUri_ = s; + } + return s; + } + } + /** + *
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + public com.google.protobuf.ByteString + getSignedUriBytes() { + java.lang.Object ref = signedUri_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + signedUri_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int TYPE_FIELD_NUMBER = 4; + private int type_; + /** + *
+     * The type of the signed credential URI (e.g., an AWS presigned URL
+     * or an Azure Shared Access Signature URI)
+     * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public boolean hasType() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + /** + *
+     * The type of the signed credential URI (e.g., an AWS presigned URL
+     * or an Azure Shared Access Signature URI)
+     * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType() { + @SuppressWarnings("deprecation") + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType result = com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.valueOf(type_); + return result == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.AZURE_SAS_URI : result; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, path_); + } + if (((bitField0_ & 0x00000004) == 0x00000004)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 3, signedUri_); + } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + output.writeEnum(4, type_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, path_); + } + if (((bitField0_ & 0x00000004) == 0x00000004)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, signedUri_); + } + if (((bitField0_ & 0x00000008) == 0x00000008)) 
{ + size += com.google.protobuf.CodedOutputStream + .computeEnumSize(4, type_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo other = (com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo) obj; + + boolean result = true; + result = result && (hasRunId() == other.hasRunId()); + if (hasRunId()) { + result = result && getRunId() + .equals(other.getRunId()); + } + result = result && (hasPath() == other.hasPath()); + if (hasPath()) { + result = result && getPath() + .equals(other.getPath()); + } + result = result && (hasSignedUri() == other.hasSignedUri()); + if (hasSignedUri()) { + result = result && getSignedUri() + .equals(other.getSignedUri()); + } + result = result && (hasType() == other.hasType()); + if (hasType()) { + result = result && type_ == other.type_; + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasRunId()) { + hash = (37 * hash) + RUN_ID_FIELD_NUMBER; + hash = (53 * hash) + getRunId().hashCode(); + } + if (hasPath()) { + hash = (37 * hash) + PATH_FIELD_NUMBER; + hash = (53 * hash) + getPath().hashCode(); + } + if (hasSignedUri()) { + hash = (37 * hash) + SIGNED_URI_FIELD_NUMBER; + hash = (53 * hash) + getSignedUri().hashCode(); + } + if (hasType()) { + hash = (37 * hash) + TYPE_FIELD_NUMBER; + hash = (53 * hash) + type_; + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + java.io.InputStream input, + 
com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.ArtifactCredentialInfo} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.ArtifactCredentialInfo) + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + runId_ = ""; + bitField0_ = (bitField0_ & ~0x00000001); + path_ = ""; + bitField0_ = (bitField0_ & ~0x00000002); + signedUri_ = ""; + bitField0_ = (bitField0_ & ~0x00000004); + type_ = 1; + bitField0_ = (bitField0_ & 
~0x00000008); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + result.runId_ = runId_; + if (((from_bitField0_ & 0x00000002) == 0x00000002)) { + to_bitField0_ |= 0x00000002; + } + result.path_ = path_; + if (((from_bitField0_ & 0x00000004) == 0x00000004)) { + to_bitField0_ |= 0x00000004; + } + result.signedUri_ = signedUri_; + if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + to_bitField0_ |= 0x00000008; + } + result.type_ = type_; + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, 
value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance()) return this; + if (other.hasRunId()) { + bitField0_ |= 0x00000001; + runId_ = other.runId_; + onChanged(); + } + if (other.hasPath()) { + bitField0_ |= 0x00000002; + path_ = other.path_; + onChanged(); + } + if (other.hasSignedUri()) { + bitField0_ |= 0x00000004; + signedUri_ = other.signedUri_; + onChanged(); + } + if (other.hasType()) { + setType(other.getType()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + 
com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private java.lang.Object runId_ = ""; + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public Builder clearRunId() { + bitField0_ = (bitField0_ & ~0x00000001); + runId_ = getDefaultInstance().getRunId(); + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run containing the artifact that can be accessed
+       * with the credential
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + + private java.lang.Object path_ = ""; + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public Builder setPath( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public Builder clearPath() { + bitField0_ = (bitField0_ & ~0x00000002); + path_ = getDefaultInstance().getPath(); + onChanged(); + return this; + } + /** + *
+       * The path, relative to the Run's artifact root location, of the artifact
+       * that can be accessed with the credential
+       * 
+ * + * optional string path = 2; + */ + public Builder setPathBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + + private java.lang.Object signedUri_ = ""; + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public boolean hasSignedUri() { + return ((bitField0_ & 0x00000004) == 0x00000004); + } + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public java.lang.String getSignedUri() { + java.lang.Object ref = signedUri_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + signedUri_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public com.google.protobuf.ByteString + getSignedUriBytes() { + java.lang.Object ref = signedUri_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + signedUri_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public Builder setSignedUri( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000004; + signedUri_ = value; + onChanged(); + return this; + } + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public Builder clearSignedUri() { + bitField0_ = (bitField0_ & ~0x00000004); + signedUri_ = getDefaultInstance().getSignedUri(); + onChanged(); + return this; + } + /** + *
+       * The signed URI credential that provides access to the artifact
+       * 
+ * + * optional string signed_uri = 3; + */ + public Builder setSignedUriBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000004; + signedUri_ = value; + onChanged(); + return this; + } + + private int type_ = 1; + /** + *
+       * The type of the signed credential URI (e.g., an AWS presigned URL
+       * or an Azure Shared Access Signature URI)
+       * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public boolean hasType() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + /** + *
+       * The type of the signed credential URI (e.g., an AWS presigned URL
+       * or an Azure Shared Access Signature URI)
+       * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType() { + @SuppressWarnings("deprecation") + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType result = com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.valueOf(type_); + return result == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.AZURE_SAS_URI : result; + } + /** + *
+       * The type of the signed credential URI (e.g., an AWS presigned URL
+       * or an Azure Shared Access Signature URI)
+       * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public Builder setType(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + type_ = value.getNumber(); + onChanged(); + return this; + } + /** + *
+       * The type of the signed credential URI (e.g., an AWS presigned URL
+       * or an Azure Shared Access Signature URI)
+       * 
+ * + * optional .mlflow.ArtifactCredentialType type = 4; + */ + public Builder clearType() { + bitField0_ = (bitField0_ & ~0x00000008); + type_ = 1; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.ArtifactCredentialInfo) + } + + // @@protoc_insertion_point(class_scope:mlflow.ArtifactCredentialInfo) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ArtifactCredentialInfo parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new ArtifactCredentialInfo(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface GetCredentialsForReadOrBuilder extends + // 
@@protoc_insertion_point(interface_extends:mlflow.GetCredentialsForRead) + com.google.protobuf.MessageOrBuilder { + + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + boolean hasRunId(); + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + java.lang.String getRunId(); + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + com.google.protobuf.ByteString + getRunIdBytes(); + + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + boolean hasPath(); + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + java.lang.String getPath(); + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + com.google.protobuf.ByteString + getPathBytes(); + } + /** + * Protobuf type {@code mlflow.GetCredentialsForRead} + */ + public static final class GetCredentialsForRead extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.GetCredentialsForRead) + GetCredentialsForReadOrBuilder { + private static final long serialVersionUID = 0L; + // Use GetCredentialsForRead.newBuilder() to construct. + private GetCredentialsForRead(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private GetCredentialsForRead() { + runId_ = ""; + path_ = ""; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private GetCredentialsForRead( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000001; + runId_ = bs; + break; + } + case 18: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000002; + path_ = bs; + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + 
} finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Builder.class); + } + + public interface ResponseOrBuilder extends + // @@protoc_insertion_point(interface_extends:mlflow.GetCredentialsForRead.Response) + com.google.protobuf.MessageOrBuilder { + + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + boolean hasCredentials(); + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials(); + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder(); + } + /** + * Protobuf type {@code mlflow.GetCredentialsForRead.Response} + */ + public static final class Response extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.GetCredentialsForRead.Response) + ResponseOrBuilder { + private static final long serialVersionUID = 0L; + // Use Response.newBuilder() to construct. + private Response(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Response() { + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private Response( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder subBuilder = null; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + subBuilder = credentials_.toBuilder(); + } + credentials_ = input.readMessage(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.PARSER, extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(credentials_); + credentials_ = subBuilder.buildPartial(); + } + bitField0_ |= 0x00000001; + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + 
break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_Response_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_Response_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.Builder.class); + } + + private int bitField0_; + public static final int CREDENTIALS_FIELD_NUMBER = 1; + private com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo credentials_; + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public boolean hasCredentials() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials() { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + /** + *
+       * Credentials for reading from the specified artifact location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder() { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + output.writeMessage(1, getCredentials()); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(1, getCredentials()); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response other = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response) obj; + + boolean result = true; + result = result && (hasCredentials() == other.hasCredentials()); + if (hasCredentials()) { + result = result && getCredentials() + .equals(other.getCredentials()); + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + 
public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasCredentials()) { + hash = (37 * hash) + CREDENTIALS_FIELD_NUMBER; + hash = (53 * hash) + getCredentials().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return 
PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return 
newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.GetCredentialsForRead.Response} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.GetCredentialsForRead.Response) + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.ResponseOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_Response_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_Response_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + 
maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + getCredentialsFieldBuilder(); + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + if (credentialsBuilder_ == null) { + credentials_ = null; + } else { + credentialsBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_Response_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + if (credentialsBuilder_ == null) { + result.credentials_ = credentials_; + } else { + result.credentials_ = credentialsBuilder_.build(); + } + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + @java.lang.Override + public 
Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response.getDefaultInstance()) return this; + if (other.hasCredentials()) { + mergeCredentials(other.getCredentials()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite 
extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo credentials_ = null; + private com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder> credentialsBuilder_; + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public boolean hasCredentials() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials() { + if (credentialsBuilder_ == null) { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } else { + return credentialsBuilder_.getMessage(); + } + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder setCredentials(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo value) { + if (credentialsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + credentials_ = value; + onChanged(); + } else { + credentialsBuilder_.setMessage(value); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder setCredentials( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder builderForValue) { + if (credentialsBuilder_ == null) { + credentials_ = builderForValue.build(); + onChanged(); + } else { + credentialsBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder mergeCredentials(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo value) { + if (credentialsBuilder_ == null) { + if (((bitField0_ & 0x00000001) == 0x00000001) && + credentials_ != null && + credentials_ != com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance()) { + credentials_ = + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.newBuilder(credentials_).mergeFrom(value).buildPartial(); + } else { + credentials_ = value; + } + onChanged(); + } else { + credentialsBuilder_.mergeFrom(value); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder clearCredentials() { + if (credentialsBuilder_ == null) { + credentials_ = null; + onChanged(); + } else { + credentialsBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder getCredentialsBuilder() { + bitField0_ |= 0x00000001; + onChanged(); + return getCredentialsFieldBuilder().getBuilder(); + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder() { + if (credentialsBuilder_ != null) { + return credentialsBuilder_.getMessageOrBuilder(); + } else { + return credentials_ == null ? + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + } + /** + *
+         * Credentials for reading from the specified artifact location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + private com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder> + getCredentialsFieldBuilder() { + if (credentialsBuilder_ == null) { + credentialsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder>( + getCredentials(), + getParentForChildren(), + isClean()); + credentials_ = null; + } + return credentialsBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.GetCredentialsForRead.Response) + } + + // @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForRead.Response) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Response parsePartialFrom( 
+ com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new Response(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Response getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private int bitField0_; + public static final int RUN_ID_FIELD_NUMBER = 1; + private volatile java.lang.Object runId_; + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } + } + /** + *
+     * The ID of the MLflow Run for which to fetch artifact read credentials
+     * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int PATH_FIELD_NUMBER = 2; + private volatile java.lang.Object path_; + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } + } + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact read credentials
+     * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, path_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, path_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead other = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead) obj; + + boolean result = true; 
+ result = result && (hasRunId() == other.hasRunId()); + if (hasRunId()) { + result = result && getRunId() + .equals(other.getRunId()); + } + result = result && (hasPath() == other.hasPath()); + if (hasPath()) { + result = result && getPath() + .equals(other.getPath()); + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasRunId()) { + hash = (37 * hash) + RUN_ID_FIELD_NUMBER; + hash = (53 * hash) + getRunId().hashCode(); + } + if (hasPath()) { + hash = (37 * hash) + PATH_FIELD_NUMBER; + hash = (53 * hash) + getPath().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.GetCredentialsForRead} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.GetCredentialsForRead) + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForReadOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.class, 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + runId_ = ""; + bitField0_ = (bitField0_ & ~0x00000001); + path_ = ""; + bitField0_ = (bitField0_ & ~0x00000002); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForRead_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + 
result.runId_ = runId_; + if (((from_bitField0_ & 0x00000002) == 0x00000002)) { + to_bitField0_ |= 0x00000002; + } + result.path_ = path_; + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead.getDefaultInstance()) return this; + if (other.hasRunId()) { + bitField0_ |= 0x00000001; + runId_ = other.runId_; + onChanged(); + } + if (other.hasPath()) { + bitField0_ |= 0x00000002; + path_ = other.path_; + onChanged(); 
+ } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private java.lang.Object runId_ = ""; + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder clearRunId() { + bitField0_ = (bitField0_ & ~0x00000001); + runId_ = getDefaultInstance().getRunId(); + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact read credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + + private java.lang.Object path_ = ""; + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder setPath( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder clearPath() { + bitField0_ = (bitField0_ & ~0x00000002); + path_ = getDefaultInstance().getPath(); + onChanged(); + return this; + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact read credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder setPathBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.GetCredentialsForRead) + } + + // @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForRead) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public GetCredentialsForRead parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new GetCredentialsForRead(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface GetCredentialsForWriteOrBuilder extends 
+ // @@protoc_insertion_point(interface_extends:mlflow.GetCredentialsForWrite) + com.google.protobuf.MessageOrBuilder { + + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + boolean hasRunId(); + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + java.lang.String getRunId(); + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + com.google.protobuf.ByteString + getRunIdBytes(); + + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + boolean hasPath(); + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + java.lang.String getPath(); + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + com.google.protobuf.ByteString + getPathBytes(); + } + /** + * Protobuf type {@code mlflow.GetCredentialsForWrite} + */ + public static final class GetCredentialsForWrite extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.GetCredentialsForWrite) + GetCredentialsForWriteOrBuilder { + private static final long serialVersionUID = 0L; + // Use GetCredentialsForWrite.newBuilder() to construct. + private GetCredentialsForWrite(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private GetCredentialsForWrite() { + runId_ = ""; + path_ = ""; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private GetCredentialsForWrite( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000001; + runId_ = bs; + break; + } + case 18: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000002; + path_ = bs; + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + 
e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Builder.class); + } + + public interface ResponseOrBuilder extends + // @@protoc_insertion_point(interface_extends:mlflow.GetCredentialsForWrite.Response) + com.google.protobuf.MessageOrBuilder { + + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + boolean hasCredentials(); + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials(); + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder(); + } + /** + * Protobuf type {@code mlflow.GetCredentialsForWrite.Response} + */ + public static final class Response extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.GetCredentialsForWrite.Response) + ResponseOrBuilder { + private static final long serialVersionUID = 0L; + // Use Response.newBuilder() to construct. + private Response(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Response() { + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private Response( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder subBuilder = null; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + subBuilder = credentials_.toBuilder(); + } + credentials_ = input.readMessage(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.PARSER, extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(credentials_); + credentials_ = subBuilder.buildPartial(); + } + bitField0_ |= 0x00000001; + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } 
+ break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_Response_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_Response_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.Builder.class); + } + + private int bitField0_; + public static final int CREDENTIALS_FIELD_NUMBER = 1; + private com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo credentials_; + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public boolean hasCredentials() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials() { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + /** + *
+       * Credentials for writing to the specified artifacts location
+       * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder() { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + output.writeMessage(1, getCredentials()); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(1, getCredentials()); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response other = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response) obj; + + boolean result = true; + result = result && (hasCredentials() == other.hasCredentials()); + if (hasCredentials()) { + result = result && getCredentials() + .equals(other.getCredentials()); + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + 
public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasCredentials()) { + hash = (37 * hash) + CREDENTIALS_FIELD_NUMBER; + hash = (53 * hash) + getCredentials().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + 
return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder 
newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.GetCredentialsForWrite.Response} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.GetCredentialsForWrite.Response) + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.ResponseOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_Response_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_Response_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent 
parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + getCredentialsFieldBuilder(); + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + if (credentialsBuilder_ == null) { + credentials_ = null; + } else { + credentialsBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_Response_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + if (credentialsBuilder_ == null) { + result.credentials_ = credentials_; + } else { + result.credentials_ = credentialsBuilder_.build(); + } + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; 
+ } + + @java.lang.Override + public Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response.getDefaultInstance()) return this; + if (other.hasCredentials()) { + mergeCredentials(other.getCredentials()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + 
com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo credentials_ = null; + private com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder> credentialsBuilder_; + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public boolean hasCredentials() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo getCredentials() { + if (credentialsBuilder_ == null) { + return credentials_ == null ? com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } else { + return credentialsBuilder_.getMessage(); + } + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder setCredentials(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo value) { + if (credentialsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + credentials_ = value; + onChanged(); + } else { + credentialsBuilder_.setMessage(value); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder setCredentials( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder builderForValue) { + if (credentialsBuilder_ == null) { + credentials_ = builderForValue.build(); + onChanged(); + } else { + credentialsBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder mergeCredentials(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo value) { + if (credentialsBuilder_ == null) { + if (((bitField0_ & 0x00000001) == 0x00000001) && + credentials_ != null && + credentials_ != com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance()) { + credentials_ = + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.newBuilder(credentials_).mergeFrom(value).buildPartial(); + } else { + credentials_ = value; + } + onChanged(); + } else { + credentialsBuilder_.mergeFrom(value); + } + bitField0_ |= 0x00000001; + return this; + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public Builder clearCredentials() { + if (credentialsBuilder_ == null) { + credentials_ = null; + onChanged(); + } else { + credentialsBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder getCredentialsBuilder() { + bitField0_ |= 0x00000001; + onChanged(); + return getCredentialsFieldBuilder().getBuilder(); + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder getCredentialsOrBuilder() { + if (credentialsBuilder_ != null) { + return credentialsBuilder_.getMessageOrBuilder(); + } else { + return credentials_ == null ? + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.getDefaultInstance() : credentials_; + } + } + /** + *
+         * Credentials for writing to the specified artifacts location
+         * 
+ * + * optional .mlflow.ArtifactCredentialInfo credentials = 1; + */ + private com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder> + getCredentialsFieldBuilder() { + if (credentialsBuilder_ == null) { + credentialsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfoOrBuilder>( + getCredentials(), + getParentForChildren(), + isClean()); + credentials_ = null; + } + return credentialsBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.GetCredentialsForWrite.Response) + } + + // @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForWrite.Response) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Response 
parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new Response(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Response getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private int bitField0_; + public static final int RUN_ID_FIELD_NUMBER = 1; + private volatile java.lang.Object runId_; + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } + } + /** + *
+     * The ID of the MLflow Run for which to fetch artifact write credentials
+     * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int PATH_FIELD_NUMBER = 2; + private volatile java.lang.Object path_; + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } + } + /** + *
+     * The artifact path, relative to the Run's artifact root location, for which to
+     * fetch artifact write credentials
+     * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, path_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, runId_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, path_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite other = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite) obj; + + boolean result = 
true; + result = result && (hasRunId() == other.hasRunId()); + if (hasRunId()) { + result = result && getRunId() + .equals(other.getRunId()); + } + result = result && (hasPath() == other.hasPath()); + if (hasPath()) { + result = result && getPath() + .equals(other.getPath()); + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasRunId()) { + hash = (37 * hash) + RUN_ID_FIELD_NUMBER; + hash = (53 * hash) + getRunId().hashCode(); + } + if (hasPath()) { + hash = (37 * hash) + PATH_FIELD_NUMBER; + hash = (53 * hash) + getPath().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.GetCredentialsForWrite} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.GetCredentialsForWrite) + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWriteOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.class, 
com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + runId_ = ""; + bitField0_ = (bitField0_ & ~0x00000001); + path_ = ""; + bitField0_ = (bitField0_ & ~0x00000002); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_GetCredentialsForWrite_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + 
result.runId_ = runId_; + if (((from_bitField0_ & 0x00000002) == 0x00000002)) { + to_bitField0_ |= 0x00000002; + } + result.path_ = path_; + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite.getDefaultInstance()) return this; + if (other.hasRunId()) { + bitField0_ |= 0x00000001; + runId_ = other.runId_; + onChanged(); + } + if (other.hasPath()) { + bitField0_ |= 0x00000002; + path_ = other.path_; + 
onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private java.lang.Object runId_ = ""; + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public boolean hasRunId() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public java.lang.String getRunId() { + java.lang.Object ref = runId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + runId_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public com.google.protobuf.ByteString + getRunIdBytes() { + java.lang.Object ref = runId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + runId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder clearRunId() { + bitField0_ = (bitField0_ & ~0x00000001); + runId_ = getDefaultInstance().getRunId(); + onChanged(); + return this; + } + /** + *
+       * The ID of the MLflow Run for which to fetch artifact write credentials
+       * 
+ * + * optional string run_id = 1; + */ + public Builder setRunIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + runId_ = value; + onChanged(); + return this; + } + + private java.lang.Object path_ = ""; + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public boolean hasPath() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public java.lang.String getPath() { + java.lang.Object ref = path_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + path_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public com.google.protobuf.ByteString + getPathBytes() { + java.lang.Object ref = path_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + path_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder setPath( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder clearPath() { + bitField0_ = (bitField0_ & ~0x00000002); + path_ = getDefaultInstance().getPath(); + onChanged(); + return this; + } + /** + *
+       * The artifact path, relative to the Run's artifact root location, for which to
+       * fetch artifact write credentials
+       * 
+ * + * optional string path = 2; + */ + public Builder setPathBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + path_ = value; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.GetCredentialsForWrite) + } + + // @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForWrite) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public GetCredentialsForWrite parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new GetCredentialsForWrite(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrite getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private static final 
com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_ArtifactCredentialInfo_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_GetCredentialsForRead_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_GetCredentialsForRead_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_GetCredentialsForRead_Response_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_GetCredentialsForRead_Response_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_GetCredentialsForWrite_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_GetCredentialsForWrite_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_GetCredentialsForWrite_Response_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_GetCredentialsForWrite_Response_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\032databricks_artifacts.proto\022\006mlflow\032\025sc" + + "alapb/scalapb.proto\032\020databricks.proto\"x\n" + + "\026ArtifactCredentialInfo\022\016\n\006run_id\030\001 \001(\t\022" + + "\014\n\004path\030\002 \001(\t\022\022\n\nsigned_uri\030\003 \001(\t\022,\n\004typ" + + "e\030\004 \001(\0162\036.mlflow.ArtifactCredentialType\"" + + 
"\243\001\n\025GetCredentialsForRead\022\016\n\006run_id\030\001 \001(" + + "\t\022\014\n\004path\030\002 \001(\t\032?\n\010Response\0223\n\013credentia" + + "ls\030\001 \001(\0132\036.mlflow.ArtifactCredentialInfo" + + ":+\342?(\n&com.databricks.rpc.RPC[$this.Resp" + + "onse]\"\244\001\n\026GetCredentialsForWrite\022\016\n\006run_" + + "id\030\001 \001(\t\022\014\n\004path\030\002 \001(\t\032?\n\010Response\0223\n\013cr" + + "edentials\030\001 \001(\0132\036.mlflow.ArtifactCredent" + + "ialInfo:+\342?(\n&com.databricks.rpc.RPC[$th" + + "is.Response]*B\n\026ArtifactCredentialType\022\021" + + "\n\rAZURE_SAS_URI\020\001\022\025\n\021AWS_PRESIGNED_URL\020\002" + + "2\342\002\n DatabricksMlflowArtifactsService\022\233\001" + + "\n\025getCredentialsForRead\022\035.mlflow.GetCred" + + "entialsForRead\032&.mlflow.GetCredentialsFo" + + "rRead.Response\";\362\206\0317\n3\n\003GET\022&/mlflow/art" + + "ifacts/credentials-for-read\032\004\010\002\020\000\020\003\022\237\001\n\026" + + "getCredentialsForWrite\022\036.mlflow.GetCrede" + + "ntialsForWrite\032\'.mlflow.GetCredentialsFo" + + "rWrite.Response\"<\362\206\0318\n4\n\003GET\022\'/mlflow/ar" + + "tifacts/credentials-for-write\032\004\010\002\020\000\020\003B,\n" + + "\037com.databricks.api.proto.mlflow\220\001\001\240\001\001\342?" + + "\002\020\001" + }; + com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = + new com.google.protobuf.Descriptors.FileDescriptor. 
InternalDescriptorAssigner() { + public com.google.protobuf.ExtensionRegistry assignDescriptors( + com.google.protobuf.Descriptors.FileDescriptor root) { + descriptor = root; + return null; + } + }; + com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + org.mlflow.scalapb_interface.Scalapb.getDescriptor(), + com.databricks.api.proto.databricks.Databricks.getDescriptor(), + }, assigner); + internal_static_mlflow_ArtifactCredentialInfo_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_ArtifactCredentialInfo_descriptor, + new java.lang.String[] { "RunId", "Path", "SignedUri", "Type", }); + internal_static_mlflow_GetCredentialsForRead_descriptor = + getDescriptor().getMessageTypes().get(1); + internal_static_mlflow_GetCredentialsForRead_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_GetCredentialsForRead_descriptor, + new java.lang.String[] { "RunId", "Path", }); + internal_static_mlflow_GetCredentialsForRead_Response_descriptor = + internal_static_mlflow_GetCredentialsForRead_descriptor.getNestedTypes().get(0); + internal_static_mlflow_GetCredentialsForRead_Response_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_GetCredentialsForRead_Response_descriptor, + new java.lang.String[] { "Credentials", }); + internal_static_mlflow_GetCredentialsForWrite_descriptor = + getDescriptor().getMessageTypes().get(2); + internal_static_mlflow_GetCredentialsForWrite_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_GetCredentialsForWrite_descriptor, + new java.lang.String[] { "RunId", "Path", }); + 
internal_static_mlflow_GetCredentialsForWrite_Response_descriptor = + internal_static_mlflow_GetCredentialsForWrite_descriptor.getNestedTypes().get(0); + internal_static_mlflow_GetCredentialsForWrite_Response_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_GetCredentialsForWrite_Response_descriptor, + new java.lang.String[] { "Credentials", }); + com.google.protobuf.ExtensionRegistry registry = + com.google.protobuf.ExtensionRegistry.newInstance(); + registry.add(com.databricks.api.proto.databricks.Databricks.rpc); + registry.add(org.mlflow.scalapb_interface.Scalapb.message); + registry.add(org.mlflow.scalapb_interface.Scalapb.options); + com.google.protobuf.Descriptors.FileDescriptor + .internalUpdateFileDescriptor(descriptor, registry); + org.mlflow.scalapb_interface.Scalapb.getDescriptor(); + com.databricks.api.proto.databricks.Databricks.getDescriptor(); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/mlflow/protos/databricks_artifacts_pb2.py b/mlflow/protos/databricks_artifacts_pb2.py new file mode 100644 index 0000000000000..8d67294e1cf93 --- /dev/null +++ b/mlflow/protos/databricks_artifacts_pb2.py @@ -0,0 +1,342 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: databricks_artifacts.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import service as _service +from google.protobuf import service_reflection +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 +from . 
import databricks_pb2 as databricks__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='databricks_artifacts.proto', + package='mlflow', + syntax='proto2', + serialized_options=_b('\n\037com.databricks.api.proto.mlflow\220\001\001\240\001\001\342?\002\020\001'), + serialized_pb=_b('\n\x1a\x64\x61tabricks_artifacts.proto\x12\x06mlflow\x1a\x15scalapb/scalapb.proto\x1a\x10\x64\x61tabricks.proto\"x\n\x16\x41rtifactCredentialInfo\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x12\x12\n\nsigned_uri\x18\x03 \x01(\t\x12,\n\x04type\x18\x04 \x01(\x0e\x32\x1e.mlflow.ArtifactCredentialType\"\xa3\x01\n\x15GetCredentialsForRead\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xa4\x01\n\x16GetCredentialsForWrite\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]*B\n\x16\x41rtifactCredentialType\x12\x11\n\rAZURE_SAS_URI\x10\x01\x12\x15\n\x11\x41WS_PRESIGNED_URL\x10\x02\x32\xe2\x02\n DatabricksMlflowArtifactsService\x12\x9b\x01\n\x15getCredentialsForRead\x12\x1d.mlflow.GetCredentialsForRead\x1a&.mlflow.GetCredentialsForRead.Response\";\xf2\x86\x19\x37\n3\n\x03GET\x12&/mlflow/artifacts/credentials-for-read\x1a\x04\x08\x02\x10\x00\x10\x03\x12\x9f\x01\n\x16getCredentialsForWrite\x12\x1e.mlflow.GetCredentialsForWrite\x1a\'.mlflow.GetCredentialsForWrite.Response\"<\xf2\x86\x19\x38\n4\n\x03GET\x12\'/mlflow/artifacts/credentials-for-write\x1a\x04\x08\x02\x10\x00\x10\x03\x42,\n\x1f\x63om.databricks.api.proto.mlflow\x90\x01\x01\xa0\x01\x01\xe2?\x02\x10\x01') + , + dependencies=[scalapb_dot_scalapb__pb2.DESCRIPTOR,databricks__pb2.DESCRIPTOR,]) + +_ARTIFACTCREDENTIALTYPE = _descriptor.EnumDescriptor( + 
name='ArtifactCredentialType', + full_name='mlflow.ArtifactCredentialType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='AZURE_SAS_URI', index=0, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='AWS_PRESIGNED_URL', index=1, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=534, + serialized_end=600, +) +_sym_db.RegisterEnumDescriptor(_ARTIFACTCREDENTIALTYPE) + +ArtifactCredentialType = enum_type_wrapper.EnumTypeWrapper(_ARTIFACTCREDENTIALTYPE) +AZURE_SAS_URI = 1 +AWS_PRESIGNED_URL = 2 + + + +_ARTIFACTCREDENTIALINFO = _descriptor.Descriptor( + name='ArtifactCredentialInfo', + full_name='mlflow.ArtifactCredentialInfo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='run_id', full_name='mlflow.ArtifactCredentialInfo.run_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='path', full_name='mlflow.ArtifactCredentialInfo.path', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='signed_uri', full_name='mlflow.ArtifactCredentialInfo.signed_uri', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', 
full_name='mlflow.ArtifactCredentialInfo.type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=79, + serialized_end=199, +) + + +_GETCREDENTIALSFORREAD_RESPONSE = _descriptor.Descriptor( + name='Response', + full_name='mlflow.GetCredentialsForRead.Response', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='credentials', full_name='mlflow.GetCredentialsForRead.Response.credentials', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=257, + serialized_end=320, +) + +_GETCREDENTIALSFORREAD = _descriptor.Descriptor( + name='GetCredentialsForRead', + full_name='mlflow.GetCredentialsForRead', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='run_id', full_name='mlflow.GetCredentialsForRead.run_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='path', full_name='mlflow.GetCredentialsForRead.path', index=1, + number=2, type=9, cpp_type=9, 
label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_GETCREDENTIALSFORREAD_RESPONSE, ], + enum_types=[ + ], + serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]'), + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=202, + serialized_end=365, +) + + +_GETCREDENTIALSFORWRITE_RESPONSE = _descriptor.Descriptor( + name='Response', + full_name='mlflow.GetCredentialsForWrite.Response', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='credentials', full_name='mlflow.GetCredentialsForWrite.Response.credentials', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=257, + serialized_end=320, +) + +_GETCREDENTIALSFORWRITE = _descriptor.Descriptor( + name='GetCredentialsForWrite', + full_name='mlflow.GetCredentialsForWrite', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='run_id', full_name='mlflow.GetCredentialsForWrite.run_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='path', full_name='mlflow.GetCredentialsForWrite.path', index=1, + number=2, 
type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_GETCREDENTIALSFORWRITE_RESPONSE, ], + enum_types=[ + ], + serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]'), + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=368, + serialized_end=532, +) + +_ARTIFACTCREDENTIALINFO.fields_by_name['type'].enum_type = _ARTIFACTCREDENTIALTYPE +_GETCREDENTIALSFORREAD_RESPONSE.fields_by_name['credentials'].message_type = _ARTIFACTCREDENTIALINFO +_GETCREDENTIALSFORREAD_RESPONSE.containing_type = _GETCREDENTIALSFORREAD +_GETCREDENTIALSFORWRITE_RESPONSE.fields_by_name['credentials'].message_type = _ARTIFACTCREDENTIALINFO +_GETCREDENTIALSFORWRITE_RESPONSE.containing_type = _GETCREDENTIALSFORWRITE +DESCRIPTOR.message_types_by_name['ArtifactCredentialInfo'] = _ARTIFACTCREDENTIALINFO +DESCRIPTOR.message_types_by_name['GetCredentialsForRead'] = _GETCREDENTIALSFORREAD +DESCRIPTOR.message_types_by_name['GetCredentialsForWrite'] = _GETCREDENTIALSFORWRITE +DESCRIPTOR.enum_types_by_name['ArtifactCredentialType'] = _ARTIFACTCREDENTIALTYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +ArtifactCredentialInfo = _reflection.GeneratedProtocolMessageType('ArtifactCredentialInfo', (_message.Message,), dict( + DESCRIPTOR = _ARTIFACTCREDENTIALINFO, + __module__ = 'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.ArtifactCredentialInfo) + )) +_sym_db.RegisterMessage(ArtifactCredentialInfo) + +GetCredentialsForRead = _reflection.GeneratedProtocolMessageType('GetCredentialsForRead', (_message.Message,), dict( + + Response = _reflection.GeneratedProtocolMessageType('Response', (_message.Message,), dict( + DESCRIPTOR = _GETCREDENTIALSFORREAD_RESPONSE, + __module__ = 
'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForRead.Response) + )) + , + DESCRIPTOR = _GETCREDENTIALSFORREAD, + __module__ = 'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForRead) + )) +_sym_db.RegisterMessage(GetCredentialsForRead) +_sym_db.RegisterMessage(GetCredentialsForRead.Response) + +GetCredentialsForWrite = _reflection.GeneratedProtocolMessageType('GetCredentialsForWrite', (_message.Message,), dict( + + Response = _reflection.GeneratedProtocolMessageType('Response', (_message.Message,), dict( + DESCRIPTOR = _GETCREDENTIALSFORWRITE_RESPONSE, + __module__ = 'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForWrite.Response) + )) + , + DESCRIPTOR = _GETCREDENTIALSFORWRITE, + __module__ = 'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.GetCredentialsForWrite) + )) +_sym_db.RegisterMessage(GetCredentialsForWrite) +_sym_db.RegisterMessage(GetCredentialsForWrite.Response) + + +DESCRIPTOR._options = None +_GETCREDENTIALSFORREAD._options = None +_GETCREDENTIALSFORWRITE._options = None + +_DATABRICKSMLFLOWARTIFACTSSERVICE = _descriptor.ServiceDescriptor( + name='DatabricksMlflowArtifactsService', + full_name='mlflow.DatabricksMlflowArtifactsService', + file=DESCRIPTOR, + index=0, + serialized_options=None, + serialized_start=603, + serialized_end=957, + methods=[ + _descriptor.MethodDescriptor( + name='getCredentialsForRead', + full_name='mlflow.DatabricksMlflowArtifactsService.getCredentialsForRead', + index=0, + containing_service=None, + input_type=_GETCREDENTIALSFORREAD, + output_type=_GETCREDENTIALSFORREAD_RESPONSE, + serialized_options=_b('\362\206\0317\n3\n\003GET\022&/mlflow/artifacts/credentials-for-read\032\004\010\002\020\000\020\003'), + ), + _descriptor.MethodDescriptor( + name='getCredentialsForWrite', + full_name='mlflow.DatabricksMlflowArtifactsService.getCredentialsForWrite', + index=1, + 
containing_service=None, + input_type=_GETCREDENTIALSFORWRITE, + output_type=_GETCREDENTIALSFORWRITE_RESPONSE, + serialized_options=_b('\362\206\0318\n4\n\003GET\022\'/mlflow/artifacts/credentials-for-write\032\004\010\002\020\000\020\003'), + ), +]) +_sym_db.RegisterServiceDescriptor(_DATABRICKSMLFLOWARTIFACTSSERVICE) + +DESCRIPTOR.services_by_name['DatabricksMlflowArtifactsService'] = _DATABRICKSMLFLOWARTIFACTSSERVICE + +DatabricksMlflowArtifactsService = service_reflection.GeneratedServiceType('DatabricksMlflowArtifactsService', (_service.Service,), dict( + DESCRIPTOR = _DATABRICKSMLFLOWARTIFACTSSERVICE, + __module__ = 'databricks_artifacts_pb2' + )) + +DatabricksMlflowArtifactsService_Stub = service_reflection.GeneratedServiceStubType('DatabricksMlflowArtifactsService_Stub', (DatabricksMlflowArtifactsService,), dict( + DESCRIPTOR = _DATABRICKSMLFLOWARTIFACTSSERVICE, + __module__ = 'databricks_artifacts_pb2' + )) + + +# @@protoc_insertion_point(module_scope) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index f94f9967caafb..afc4b9e10f4fe 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -10,46 +10,54 @@ from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead from mlflow.utils.databricks_utils import get_databricks_host_creds from mlflow.protos.service_pb2 import MlflowService, ListArtifacts +from mlflow.utils.uri import extract_and_normalize_path _PATH_PREFIX = "/api/2.0" class DatabricksArtifactRepository(ArtifactRepository): """ - SOMETHING : TYPING TILL IT WORKS LOL + Stores artifacts on Azure/AWS with access control. + + The artifact_uri is expected to be of the form dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts/.. 
""" def __init__(self, artifact_uri): super(DatabricksArtifactRepository, self).__init__(artifact_uri) + self._SERVICE_AND_METHOD_TO_INFO = { + service: extract_api_info_for_service(service, _PATH_PREFIX) + for service in [MlflowService, DatabricksMlflowArtifactsService] + } + self.credential_type_to_cloud_service = { + + } def _extract_run_id(self, artifact_uri): - return artifact_uri.lstrip('/').split('/')[4] + artifact_path = extract_and_normalize_path(artifact_uri) + return artifact_path.split('/')[3] def _call_endpoint(self, service, api, json_body): - _METHOD_TO_INFO = extract_api_info_for_service(service, _PATH_PREFIX) - endpoint, method = _METHOD_TO_INFO[api] + endpoint, method = self._SERVICE_AND_METHOD_TO_INFO[service][api] response_proto = api.Response() return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto) def _create_json_body(self, run_id, path=None): - path = path or '.' + path = path or "" return { "run_id": run_id, "path": path } - def _get_azure_write_credentials(self, run_id, path=None): + def _get_write_credentials(self, run_id, path=None): return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite, self._create_json_body(run_id, path)) - def _get_azure_read_credentials(self, run_id, path=None): + def _get_read_credentials(self, run_id, path=None): return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, self._create_json_body(run_id, path)) - def _upload_file(self, local_file, artifact_path): - run_id = self._extract_run_id(self.artifact_uri) - write_credentials = self._get_azure_write_credentials(run_id, artifact_path) - signed_write_uri = write_credentials.credentials.signed_uri + def _azure_upload_file(self, credentials, local_file): + signed_write_uri = credentials.signed_uri service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) try: with open(local_file, "rb") as data: @@ -57,8 +65,29 @@ def _upload_file(self, local_file, 
artifact_path): except Exception as err: raise MlflowException(err) + def _azure_download_file(self, credentials, local_path): + signed_read_uri = credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_read_uri, credential=None) + try: + with open(local_path, "wb") as output_file: + blob = service.download_blob() + output_file.write(blob.readall()) + except Exception as err: + raise MlflowException(err) + + def _aws_upload_file(self, credentials, local_file): + pass + + def _aws_download_file(self, credentials, local_path): + pass + def log_artifact(self, local_file, artifact_path=None): - self._upload_file(local_file, artifact_path) + run_id = self._extract_run_id(self.artifact_uri) + write_credentials = self._get_write_credentials(run_id, artifact_path) + if write_credentials.credentials.type == 1: + self._azure_upload_file(write_credentials.credentials, local_file) + else: + raise MlflowException('Not implemented yet') def log_artifacts(self, local_dir, artifact_path=None): artifact_path = artifact_path or '' @@ -72,7 +101,7 @@ def log_artifacts(self, local_dir, artifact_path=None): for name in filenames: local_file = os.path.join(dirpath, name) artifact_location = os.path.join(artifact_path, artifact_subdir) - self._upload_file(local_file, artifact_location) + self.log_artifact(local_file, artifact_location) def list_artifacts(self, path=None): run_id = self._extract_run_id(self.artifact_uri) @@ -80,15 +109,11 @@ def list_artifacts(self, path=None): def _download_file(self, remote_file_path, local_path): run_id = self._extract_run_id(self.artifact_uri) - read_credentials = self._get_azure_read_credentials(run_id, remote_file_path) - signed_read_uri = read_credentials.credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_read_uri, credential=None) - try: - with open(local_path, "wb") as output_file: - blob = service.download_blob() - output_file.write(blob.readall()) - except Exception as err: - raise 
MlflowException(err) + read_credentials = self._get_read_credentials(run_id, remote_file_path) + if read_credentials.credentials.type == 1: + self._azure_upload_file(read_credentials.credentials, local_path) + else: + raise MlflowException('Not implemented yet') def delete_artifacts(self, artifact_path=None): raise MlflowException('Not implemented yet') diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 4600a0f9741f9..9c70c358ffd51 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -13,7 +13,7 @@ from mlflow.utils.rest_utils import http_request, http_request_safe, RESOURCE_DOES_NOT_EXIST from mlflow.utils.string_utils import strip_prefix import mlflow.utils.databricks_utils -from mlflow.utils.uri import is_artifact_acled_uri +from mlflow.utils.uri import is_databricks_acled_artifacts_uri LIST_API_ENDPOINT = '/api/2.0/dbfs/list' GET_STATUS_ENDPOINT = '/api/2.0/dbfs/get-status' @@ -165,7 +165,7 @@ def dbfs_artifact_repo_factory(artifact_uri): :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS. """ cleaned_artifact_uri = artifact_uri.rstrip('/') - if is_artifact_acled_uri(artifact_uri): + if is_databricks_acled_artifacts_uri(artifact_uri): return DatabricksArtifactRepository(artifact_uri) elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false" \ diff --git a/mlflow/utils/uri.py b/mlflow/utils/uri.py index ec7a4328c24c5..f137ddd58f178 100644 --- a/mlflow/utils/uri.py +++ b/mlflow/utils/uri.py @@ -8,7 +8,7 @@ _INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \ "format specifications." 
-_ACLED_ARTIFACT_URI = "dbfs:/databricks/mlflow-tracking/" + def is_local_uri(uri): """Returns true if this is a local file path (/foo or file:/foo).""" @@ -67,6 +67,12 @@ def get_uri_scheme(uri_or_path): return scheme +def extract_and_normalize_path(uri): + parsed_uri_path = urllib.parse.urlparse(uri).path + normalized_path = posixpath.normpath(parsed_uri_path) + return normalized_path.lstrip("/") + + def append_to_uri_path(uri, *paths): """ Appends the specified POSIX `paths` to the path component of the specified `uri`. @@ -131,5 +137,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path): return posixpath.join(prefix_path, suffix_path) -def is_artifact_acled_uri(artifact_uri): - return artifact_uri.startswith(_ACLED_ARTIFACT_URI.lstrip('/')) +def is_databricks_acled_artifacts_uri(artifact_uri): + _ACLED_ARTIFACT_URI = "databricks/mlflow-tracking/" + artifact_uri_path = extract_and_normalize_path(artifact_uri) + return artifact_uri_path.startswith(_ACLED_ARTIFACT_URI) From 38630d37c12298bf134149a919d35c93ffd97454 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 20 May 2020 00:24:06 -0700 Subject: [PATCH 04/28] Addressing comments and fixing _download_file --- .../artifact/databricks_artifact_repo.py | 20 ++++++++++++------- mlflow/store/artifact/dbfs_artifact_repo.py | 11 +++++++++- setup.py | 2 +- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index afc4b9e10f4fe..810aafa102e97 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -11,6 +11,7 @@ from mlflow.utils.databricks_utils import get_databricks_host_creds from mlflow.protos.service_pb2 import MlflowService, ListArtifacts from mlflow.utils.uri import extract_and_normalize_path +from mlflow.utils.proto_json_utils import message_to_json _PATH_PREFIX = "/api/2.0" @@ -19,7 +20,8 @@ class 
DatabricksArtifactRepository(ArtifactRepository): """ Stores artifacts on Azure/AWS with access control. - The artifact_uri is expected to be of the form dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts/.. + The artifact_uri is expected to be of the form + dbfs:/databricks/mlflow-tracking///artifacts/ """ def __init__(self, artifact_uri): @@ -49,12 +51,12 @@ def _create_json_body(self, run_id, path=None): } def _get_write_credentials(self, run_id, path=None): - return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite, - self._create_json_body(run_id, path)) + json_body = message_to_json(GetCredentialsForWrite(run_id=run_id, path=path)) + return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite, json_body) def _get_read_credentials(self, run_id, path=None): - return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, - self._create_json_body(run_id, path)) + json_body = message_to_json(GetCredentialsForRead(run_id=run_id, path=path)) + return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body) def _azure_upload_file(self, credentials, local_file): signed_write_uri = credentials.signed_uri @@ -82,6 +84,9 @@ def _aws_download_file(self, credentials, local_path): pass def log_artifact(self, local_file, artifact_path=None): + basename = os.path.basename(local_file) + artifact_path = artifact_path or "" + artifact_path = os.path.join(artifact_path, basename) run_id = self._extract_run_id(self.artifact_uri) write_credentials = self._get_write_credentials(run_id, artifact_path) if write_credentials.credentials.type == 1: @@ -105,13 +110,14 @@ def log_artifacts(self, local_dir, artifact_path=None): def list_artifacts(self, path=None): run_id = self._extract_run_id(self.artifact_uri) - return self._call_endpoint(MlflowService, ListArtifacts, self._create_json_body(run_id, path)) + json_body = message_to_json(ListArtifacts(run_id=run_id, path=path)) + 
return self._call_endpoint(MlflowService, ListArtifacts, json_body).files def _download_file(self, remote_file_path, local_path): run_id = self._extract_run_id(self.artifact_uri) read_credentials = self._get_read_credentials(run_id, remote_file_path) if read_credentials.credentials.type == 1: - self._azure_upload_file(read_credentials.credentials, local_path) + self._azure_download_file(read_credentials.credentials, local_path) else: raise MlflowException('Not implemented yet') diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 9c70c358ffd51..67b8b8069121e 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -13,7 +13,7 @@ from mlflow.utils.rest_utils import http_request, http_request_safe, RESOURCE_DOES_NOT_EXIST from mlflow.utils.string_utils import strip_prefix import mlflow.utils.databricks_utils -from mlflow.utils.uri import is_databricks_acled_artifacts_uri +from mlflow.utils.uri import is_databricks_acled_artifacts_uri, get_uri_scheme LIST_API_ENDPOINT = '/api/2.0/dbfs/list' GET_STATUS_ENDPOINT = '/api/2.0/dbfs/get-status' @@ -161,10 +161,19 @@ def dbfs_artifact_repo_factory(artifact_uri): This factory method is used with URIs of the form ``dbfs:/``. DBFS-backed artifact storage can only be used together with the RestStore. + + In the special case where the URI is of the form ``dbfs:/databricks/mlflow-tracking///, + a DatabricksArtifactRepository is returned. This is capable of storing access controlled artifacts. + + :param artifact_uri: DBFS root artifact URI (string). :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS. 
""" cleaned_artifact_uri = artifact_uri.rstrip('/') + uri_scheme = get_uri_scheme(artifact_uri) + if uri_scheme != 'dbfs': + raise Exception("DBFS URI must be of the form " + "dbfs:/, but received {uri}".format(uri=artifact_uri)) if is_databricks_acled_artifacts_uri(artifact_uri): return DatabricksArtifactRepository(artifact_uri) elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ diff --git a/setup.py b/setup.py index 53d50d1d87f35..a8aeaecfa7662 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ def package_files(directory): package_data={"mlflow": js_files + models_container_server_files + alembic_files}, install_requires=[ 'alembic', + 'azure-storage-blob>=12.0', 'click>=7.0', 'cloudpickle', 'databricks-cli>=0.8.7', @@ -59,7 +60,6 @@ def package_files(directory): "scikit-learn==0.20; python_version < '3.5'", 'boto3>=1.7.12', 'mleap>=0.8.1', - 'azure-storage-blob>=12.0', 'google-cloud-storage', 'azureml-core>=1.2.0' ], From 1cb98e9c74a2ce2579c4bbe6d8d4ecefb84187f1 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 20 May 2020 15:08:59 -0700 Subject: [PATCH 05/28] Code Clean-up and lint --- generate-protos.sh | 2 +- .../artifact/databricks_artifact_repo.py | 34 +++++++++++-------- mlflow/store/artifact/dbfs_artifact_repo.py | 7 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/generate-protos.sh b/generate-protos.sh index bcd286ecbf760..16232b7aa64de 100755 --- a/generate-protos.sh +++ b/generate-protos.sh @@ -28,4 +28,4 @@ sed -i'.old' -e "s/$OLD_DATABRICKS/$NEW_DATABRICKS/g" "$PROTOS/service_pb2.py" " rm "$PROTOS/databricks_pb2.py.old" rm "$PROTOS/service_pb2.py.old" rm "$PROTOS/model_registry_pb2.py.old" -rm "$PROTOS/databricks_artifacts__pb2.py.old" +rm "$PROTOS/databricks_artifacts_pb2.py.old" diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 810aafa102e97..db4f51d2d8e36 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ 
b/mlflow/store/artifact/databricks_artifact_repo.py @@ -30,9 +30,6 @@ def __init__(self, artifact_uri): service: extract_api_info_for_service(service, _PATH_PREFIX) for service in [MlflowService, DatabricksMlflowArtifactsService] } - self.credential_type_to_cloud_service = { - - } def _extract_run_id(self, artifact_uri): artifact_path = extract_and_normalize_path(artifact_uri) @@ -41,7 +38,8 @@ def _extract_run_id(self, artifact_uri): def _call_endpoint(self, service, api, json_body): endpoint, method = self._SERVICE_AND_METHOD_TO_INFO[service][api] response_proto = api.Response() - return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto) + return call_endpoint(get_databricks_host_creds(), + endpoint, method, json_body, response_proto) def _create_json_body(self, run_id, path=None): path = path or "" @@ -52,11 +50,13 @@ def _create_json_body(self, run_id, path=None): def _get_write_credentials(self, run_id, path=None): json_body = message_to_json(GetCredentialsForWrite(run_id=run_id, path=path)) - return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite, json_body) + return self._call_endpoint(DatabricksMlflowArtifactsService, + GetCredentialsForWrite, json_body) def _get_read_credentials(self, run_id, path=None): json_body = message_to_json(GetCredentialsForRead(run_id=run_id, path=path)) - return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body) + return self._call_endpoint(DatabricksMlflowArtifactsService, + GetCredentialsForRead, json_body) def _azure_upload_file(self, credentials, local_file): signed_write_uri = credentials.signed_uri @@ -83,16 +83,25 @@ def _aws_upload_file(self, credentials, local_file): def _aws_download_file(self, credentials, local_path): pass + def _upload_to_cloud(self, cloud_credentials, local_file): + if cloud_credentials.credentials.type == 1: + self._azure_upload_file(cloud_credentials.credentials, local_file) + else: + raise 
MlflowException('Not implemented yet') + + def _download_from_cloud(self, cloud_credentials, local_path): + if cloud_credentials.credentials.type == 1: + self._azure_download_file(cloud_credentials.credentials, local_path) + else: + raise MlflowException('Not implemented yet') + def log_artifact(self, local_file, artifact_path=None): basename = os.path.basename(local_file) artifact_path = artifact_path or "" artifact_path = os.path.join(artifact_path, basename) run_id = self._extract_run_id(self.artifact_uri) write_credentials = self._get_write_credentials(run_id, artifact_path) - if write_credentials.credentials.type == 1: - self._azure_upload_file(write_credentials.credentials, local_file) - else: - raise MlflowException('Not implemented yet') + self._upload_to_cloud(write_credentials, local_file) def log_artifacts(self, local_dir, artifact_path=None): artifact_path = artifact_path or '' @@ -116,10 +125,7 @@ def list_artifacts(self, path=None): def _download_file(self, remote_file_path, local_path): run_id = self._extract_run_id(self.artifact_uri) read_credentials = self._get_read_credentials(run_id, remote_file_path) - if read_credentials.credentials.type == 1: - self._azure_download_file(read_credentials.credentials, local_path) - else: - raise MlflowException('Not implemented yet') + self._download_from_cloud(read_credentials, local_path) def delete_artifacts(self, artifact_path=None): raise MlflowException('Not implemented yet') diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 67b8b8069121e..9b3041705d1e8 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -162,9 +162,10 @@ def dbfs_artifact_repo_factory(artifact_uri): This factory method is used with URIs of the form ``dbfs:/``. DBFS-backed artifact storage can only be used together with the RestStore. 
- In the special case where the URI is of the form ``dbfs:/databricks/mlflow-tracking///, - a DatabricksArtifactRepository is returned. This is capable of storing access controlled artifacts. - + In the special case where the URI is of the form + ``dbfs:/databricks/mlflow-tracking///, + a DatabricksArtifactRepository is returned. This is capable of storing access controlled + artifacts. :param artifact_uri: DBFS root artifact URI (string). :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS. From 870c0aaa20e5117a7ea81d43e0d8e41d11f98165 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Fri, 22 May 2020 15:42:49 -0700 Subject: [PATCH 06/28] Adding multi-part upload logic and unit tests --- mlflow/store/artifact/artifact_repo.py | 21 +-- .../artifact/databricks_artifact_repo.py | 81 ++++++---- mlflow/utils/file_utils.py | 13 ++ .../artifact/test_databricks_artifact_repo.py | 138 ++++++++++++++++++ .../test_dbfs_artifact_repo_delegation.py | 6 + tests/utils/test_uri.py | 23 ++- 6 files changed, 235 insertions(+), 47 deletions(-) create mode 100644 tests/store/artifact/test_databricks_artifact_repo.py diff --git a/mlflow/store/artifact/artifact_repo.py b/mlflow/store/artifact/artifact_repo.py index 408377083c870..864b0de989fcd 100644 --- a/mlflow/store/artifact/artifact_repo.py +++ b/mlflow/store/artifact/artifact_repo.py @@ -80,22 +80,6 @@ def download_artifacts(self, artifact_path, dst_path=None): # TODO: Probably need to add a more efficient method to stream just a single artifact # without downloading it, or to get a pre-signed URL for cloud storage. - if dst_path is None: - dst_path = tempfile.mkdtemp() - dst_path = os.path.abspath(dst_path) - if not os.path.exists(dst_path): - raise MlflowException( - message=( - "The destination path for downloaded artifacts does not" - " exist! 
Destination path: {dst_path}".format(dst_path=dst_path)), - error_code=RESOURCE_DOES_NOT_EXIST) - elif not os.path.isdir(dst_path): - raise MlflowException( - message=( - "The destination path for downloaded artifacts must be a directory!" - " Destination path: {dst_path}".format(dst_path=dst_path)), - error_code=INVALID_PARAMETER_VALUE) - def download_file(fullpath): dirpath, _ = posixpath.split(fullpath) local_dir_path = os.path.join(dst_path, dirpath) @@ -120,6 +104,11 @@ def download_artifact_dir(dir_path): else: download_file(file_info.path) return local_dir + + if dst_path is None: + dst_path = tempfile.mkdtemp() + dst_path = os.path.abspath(dst_path) + if not os.path.exists(dst_path): raise MlflowException( message=( diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index db4f51d2d8e36..e964fcb349d69 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -3,17 +3,18 @@ import os from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repo import ArtifactRepository -from mlflow.utils.string_utils import strip_suffix -from mlflow.utils.file_utils import relative_path_to_artifact_path -from mlflow.utils.rest_utils import call_endpoint, extract_api_info_for_service -from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService -from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead -from mlflow.utils.databricks_utils import get_databricks_host_creds +from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, GetCredentialsForWrite, \ + GetCredentialsForRead from mlflow.protos.service_pb2 import MlflowService, ListArtifacts -from mlflow.utils.uri import extract_and_normalize_path +from mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri from mlflow.utils.proto_json_utils import 
message_to_json +from mlflow.utils.file_utils import relative_path_to_artifact_path, yield_file_in_chunks +from mlflow.utils.rest_utils import call_endpoint, extract_api_info_for_service +from mlflow.utils.databricks_utils import get_databricks_host_creds _PATH_PREFIX = "/api/2.0" +_AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE = 256000000 - 1 # Can upload blob in single request if it is no more than 256 MB +_AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Maximum size of each block allowed is 100 MB in stage blob class DatabricksArtifactRepository(ArtifactRepository): @@ -26,13 +27,20 @@ class DatabricksArtifactRepository(ArtifactRepository): def __init__(self, artifact_uri): super(DatabricksArtifactRepository, self).__init__(artifact_uri) + if not artifact_uri.startswith('dbfs:/'): + raise MlflowException('DatabricksArtifactRepository URI must start with dbfs:/') + if not is_databricks_acled_artifacts_uri(artifact_uri): + raise MlflowException('Artifact URI incorrect. Expected path prefix to be ' + 'databricks/mlflow-tracking/path/to/artifact/..') + + self.run_id = self._extract_run_id() self._SERVICE_AND_METHOD_TO_INFO = { service: extract_api_info_for_service(service, _PATH_PREFIX) for service in [MlflowService, DatabricksMlflowArtifactsService] } - def _extract_run_id(self, artifact_uri): - artifact_path = extract_and_normalize_path(artifact_uri) + def _extract_run_id(self): - artifact_path = extract_and_normalize_path(artifact_uri) + artifact_path = extract_and_normalize_path(self.artifact_uri) return artifact_path.split('/')[3] def _call_endpoint(self, service, api, json_body): @@ -58,12 +66,24 @@ def _get_read_credentials(self, run_id, path=None): return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body) - def _azure_upload_file(self, credentials, local_file): - signed_write_uri = credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + def _azure_upload_file(self, credentials, local_file, artifact_path): try: - with open(local_file, "rb") as data: -
service.upload_blob(data, overwrite=True) + if os.path.getsize(local_file) < _AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE: + signed_write_uri = credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + with open(local_file, "rb") as data: + service.upload_blob(data, overwrite=True) + else: + uploading_block_list = list() + for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE): + signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + block_id = base64.b64encode(uuid.uuid4().hex.encode()) + service.stage_block(block_id, chunk) + uploading_block_list.append(block_id) + signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + service.commit_block_list(uploading_block_list) except Exception as err: raise MlflowException(err) @@ -83,9 +103,9 @@ def _aws_upload_file(self, credentials, local_file): def _aws_download_file(self, credentials, local_path): pass - def _upload_to_cloud(self, cloud_credentials, local_file): + def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path): if cloud_credentials.credentials.type == 1: - self._azure_upload_file(cloud_credentials.credentials, local_file) + self._azure_upload_file(cloud_credentials.credentials, local_file, artifact_path) else: raise MlflowException('Not implemented yet') @@ -99,32 +119,33 @@ def log_artifact(self, local_file, artifact_path=None): basename = os.path.basename(local_file) artifact_path = artifact_path or "" artifact_path = os.path.join(artifact_path, basename) - run_id = self._extract_run_id(self.artifact_uri) - write_credentials = self._get_write_credentials(run_id, artifact_path) - self._upload_to_cloud(write_credentials, local_file) + write_credentials = 
self._get_write_credentials(self.run_id, artifact_path) + self._upload_to_cloud(write_credentials, local_file, artifact_path) def log_artifacts(self, local_dir, artifact_path=None): artifact_path = artifact_path or '' - basename = os.path.basename(strip_suffix(local_dir, '/')) for (dirpath, _, filenames) in os.walk(local_dir): - artifact_subdir = basename + artifact_subdir = artifact_path if dirpath != local_dir: rel_path = os.path.relpath(dirpath, local_dir) rel_path = relative_path_to_artifact_path(rel_path) - artifact_subdir = os.path.join(artifact_subdir, rel_path) + artifact_subdir = os.path.join(artifact_path, rel_path) for name in filenames: - local_file = os.path.join(dirpath, name) - artifact_location = os.path.join(artifact_path, artifact_subdir) - self.log_artifact(local_file, artifact_location) + file_path = os.path.join(dirpath, name) + self.log_artifact(file_path, artifact_subdir) def list_artifacts(self, path=None): - run_id = self._extract_run_id(self.artifact_uri) - json_body = message_to_json(ListArtifacts(run_id=run_id, path=path)) - return self._call_endpoint(MlflowService, ListArtifacts, json_body).files + json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path)) + artifact_list = self._call_endpoint(MlflowService, ListArtifacts, json_body).files + # If `path` is a file, ListArtifacts returns a single list element with the + # same name as `path`. The list_artifacts API expects us to return an empty list in this + # case, so we do so here. 
+ if len(artifact_list) == 1 and artifact_list[0].path == path: + return [] + return artifact_list def _download_file(self, remote_file_path, local_path): - run_id = self._extract_run_id(self.artifact_uri) - read_credentials = self._get_read_credentials(run_id, remote_file_path) + read_credentials = self._get_read_credentials(self.run_id, remote_file_path) self._download_from_cloud(read_credentials, local_path) def delete_artifacts(self, artifact_path=None): diff --git a/mlflow/utils/file_utils.py b/mlflow/utils/file_utils.py index ad06659b4c17b..568a1a617e07c 100644 --- a/mlflow/utils/file_utils.py +++ b/mlflow/utils/file_utils.py @@ -396,3 +396,16 @@ def get_local_path_or_none(path_or_uri): return local_file_uri_to_path(path_or_uri) else: return None + + +def yield_file_in_chunks(file, chunk_size=100000000): + """ + Generator to chunk-ify the inputted file based on the chunk-size. + """ + with open(file, "rb") as f: + while True: + chunk = f.read(chunk_size) + if chunk: + yield chunk + else: + break diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py new file mode 100644 index 0000000000000..f04a245ee7dea --- /dev/null +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +import os + +import pytest +import mock +from unittest.mock import ANY + +from mlflow.exceptions import MlflowException +from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository +from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository +from mlflow.protos.service_pb2 import ListArtifacts, FileInfo +from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \ + ArtifactCredentialType, ArtifactCredentialInfo + + +@pytest.fixture() +def databricks_artifact_repo(): + return get_artifact_repository('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN/artifact') + + 
+DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo' +DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository" + +TEST_FILE_1_CONTENT = u"Hello 🍆🍔".encode("utf-8") +TEST_FILE_2_CONTENT = u"World 🍆🍔🍆".encode("utf-8") +TEST_FILE_3_CONTENT = u"¡🍆🍆🍔🍆🍆!".encode("utf-8") + + +@pytest.fixture() +def test_file(tmpdir): + p = tmpdir.join("test.txt") + with open(p.strpath, 'wb') as f: + f.write(TEST_FILE_1_CONTENT) + return p + + +@pytest.fixture() +def test_dir(tmpdir): + with open(tmpdir.mkdir('subdir').join('test.txt').strpath, 'wb') as f: + f.write(TEST_FILE_2_CONTENT) + with open(tmpdir.join('test.txt').strpath, 'wb') as f: + f.write(bytes(TEST_FILE_3_CONTENT)) + with open(tmpdir.join('empty-file').strpath, 'w'): + pass + return tmpdir + + +LIST_ARTIFACTS_PROTO_RESPONSE = [FileInfo(path='test/a.txt', is_dir=False, file_size=100), + FileInfo(path='test/dir', is_dir=True, file_size=0)] + +LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE = [FileInfo(path='a.txt', is_dir=False, file_size=0)] +MOCK_AZURE_SIGNED_URI = "this_is_a_mock_sas_for_azure" +MOCK_RUN_ID = 'MOCK-RUN' + + +class TestDatabricksArtifactRepository(object): + def test_init_validation_and_cleaning(self): + repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact') + assert repo.artifact_uri == 'dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact' + with pytest.raises(MlflowException): + DatabricksArtifactRepository('s3://test') + with pytest.raises(MlflowException): + DatabricksArtifactRepository('dbfs:/databricks/mlflow/EXP/RUN/artifact') + + @pytest.mark.parametrize("artifact_path,expected_location", [ + (None, 'test.txt'), + ('output', 'output/test.txt'), + ('', 'test.txt'), + ]) + def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, expected_location): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') as write_credentials_mock, \ + 
mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') as azure_upload_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + # get_run_id_mock.return_value = MOCK_RUN_ID + azure_upload_mock.return_value = None + databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) + azure_upload_mock.assert_called_with(mock_credentials, test_file.strpath) + + @pytest.mark.parametrize("artifact_path", [ + None, + 'output/', + '', + ]) + def test_log_artifacts(self, databricks_artifact_repo, test_dir, artifact_path): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.log_artifact') as log_artifact_mock: + log_artifact_mock.return_value = None + databricks_artifact_repo.log_artifacts(test_dir.strpath, artifact_path) + artifact_path = artifact_path or '' + calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), os.path.join(artifact_path, '')), + mock.call(os.path.join(test_dir.strpath, 'test.txt'), os.path.join(artifact_path, '')), + mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), + os.path.join(artifact_path, 'subdir'))] + log_artifact_mock.assert_has_calls(calls) + + def test_list_artifacts(self, databricks_artifact_repo): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: + list_artifact_response_proto = ListArtifacts.Response(root_uri='', files=LIST_ARTIFACTS_PROTO_RESPONSE) + call_endpoint_mock.return_value = list_artifact_response_proto + artifacts = databricks_artifact_repo.list_artifacts('test/') + assert len(artifacts) == 2 + assert artifacts[0].path == 'test/a.txt' + assert artifacts[0].is_dir is False + assert artifacts[0].file_size == 100 + assert 
artifacts[1].path == 'test/dir' + assert artifacts[1].is_dir is True + assert artifacts[1].file_size is 0 + + # Calling list_artifacts() on a path that's a file should return an empty list + list_artifact_response_proto = ListArtifacts.Response(root_uri='', + files=LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE) + call_endpoint_mock.return_value = list_artifact_response_proto + artifacts = databricks_artifact_repo.list_artifacts('a.txt') + assert len(artifacts) == 0 + + @pytest.mark.parametrize("remote_file_path, local_path", [ + ('test_file.txt', ''), + ('test_file.txt', None), + ('output/test_file', None), + ]) + def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path, local_path): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') as read_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_download_file') as azure_download_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + read_credentials_response_proto = GetCredentialsForRead.Response( + credentials=mock_credentials) + read_credentials_mock.return_value = read_credentials_response_proto + azure_download_mock.return_value = None + get_list_mock.return_value = [] + databricks_artifact_repo.download_artifacts(remote_file_path, local_path) + read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path) + azure_download_mock.assert_called_with(mock_credentials, ANY) diff --git a/tests/store/artifact/test_dbfs_artifact_repo_delegation.py b/tests/store/artifact/test_dbfs_artifact_repo_delegation.py index ac69a919bb601..419fd9469b396 100644 --- a/tests/store/artifact/test_dbfs_artifact_repo_delegation.py +++ b/tests/store/artifact/test_dbfs_artifact_repo_delegation.py @@ -5,6 +5,7 @@ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from 
mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository from mlflow.store.artifact.dbfs_artifact_repo import DbfsRestArtifactRepository +from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository from mlflow.utils.rest_utils import MlflowHostCreds @@ -39,3 +40,8 @@ def test_dbfs_artifact_repo_delegates_to_correct_repo( rest_repo = get_artifact_repository(artifact_uri) assert isinstance(rest_repo, DbfsRestArtifactRepository) assert rest_repo.artifact_uri == artifact_uri + + artifact_uri = "dbfs:/databricks/mlflow-tracking/my/absolute/dbfs/path" + databricks_repo = get_artifact_repository(artifact_uri) + assert isinstance(databricks_repo, DatabricksArtifactRepository) + assert databricks_repo.artifact_uri == artifact_uri diff --git a/tests/utils/test_uri.py b/tests/utils/test_uri.py index 161fcc948add9..90ac58afc99b7 100644 --- a/tests/utils/test_uri.py +++ b/tests/utils/test_uri.py @@ -4,7 +4,8 @@ from mlflow.exceptions import MlflowException from mlflow.store.db.db_types import DATABASE_ENGINES from mlflow.utils.uri import is_databricks_uri, is_http_uri, is_local_uri, \ - extract_db_type_from_uri, get_db_profile_from_uri, get_uri_scheme, append_to_uri_path + extract_db_type_from_uri, get_db_profile_from_uri, get_uri_scheme, append_to_uri_path, \ + extract_and_normalize_path, is_databricks_acled_artifacts_uri def test_extract_db_type_from_uri(): @@ -191,3 +192,23 @@ def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments(): "creds=mycreds,param=value#*frag@*" ), ]) + + +def test_extract_and_normalize_path(): + base_uri = 'databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts' + assert extract_and_normalize_path('dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert extract_and_normalize_path('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert extract_and_normalize_path('dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert 
extract_and_normalize_path('dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') == base_uri + assert extract_and_normalize_path('dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri + assert extract_and_normalize_path('dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri + + +def test_is_databricks_acled_artifacts_uri(): + assert is_databricks_acled_artifacts_uri('dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri('dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri('dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') + assert is_databricks_acled_artifacts_uri('dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert is_databricks_acled_artifacts_uri('dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert not is_databricks_acled_artifacts_uri('dbfs:/databricks/mlflow//EXP_ID//RUN_ID///artifacts//') From 10cf458f525dc538d72928cb1139354a8a73f138 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Tue, 26 May 2020 11:38:20 -0700 Subject: [PATCH 07/28] Addressing comments --- mlflow/store/artifact/artifact_repo.py | 1 + .../artifact/databricks_artifact_repo.py | 70 ++++++++++--------- .../artifact/test_databricks_artifact_repo.py | 1 - 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/mlflow/store/artifact/artifact_repo.py b/mlflow/store/artifact/artifact_repo.py index 864b0de989fcd..ce43cb885d9d0 100644 --- a/mlflow/store/artifact/artifact_repo.py +++ b/mlflow/store/artifact/artifact_repo.py @@ -86,6 +86,7 @@ def download_file(fullpath): local_file_path = os.path.join(dst_path, fullpath) if not os.path.exists(local_dir_path): os.makedirs(local_dir_path) + print (fullpath) self._download_file(remote_file_path=fullpath, 
local_path=local_file_path) return local_file_path diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index e964fcb349d69..2096f6a557443 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -1,10 +1,14 @@ from azure.storage.blob import BlobClient +from azure.core.exceptions import ClientAuthenticationError import os +import uuid +import base64 + from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repo import ArtifactRepository from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, GetCredentialsForWrite, \ - GetCredentialsForRead + GetCredentialsForRead, ArtifactCredentialType from mlflow.protos.service_pb2 import MlflowService, ListArtifacts from mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri from mlflow.utils.proto_json_utils import message_to_json @@ -13,8 +17,11 @@ from mlflow.utils.databricks_utils import get_databricks_host_creds _PATH_PREFIX = "/api/2.0" -_AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE = 256000000 - 1 # Can upload blob in single request if it is no more thn 256 MB -_AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Maximum size of each block allowed is 100 MB in stage blob +_AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Maximum size of each block allowed is 100 MB in stage_block +_SERVICE_AND_METHOD_TO_INFO = { + service: extract_api_info_for_service(service, _PATH_PREFIX) + for service in [MlflowService, DatabricksMlflowArtifactsService] +} class DatabricksArtifactRepository(ArtifactRepository): @@ -32,30 +39,18 @@ def __init__(self, artifact_uri): if not is_databricks_acled_artifacts_uri(artifact_uri): raise MlflowException('Artifact URI incorrect. 
Expected path prefix to be ' 'databricks/mlflow-tracking/path/to/artifact/..') - self.run_id = self._extract_run_id() - self._SERVICE_AND_METHOD_TO_INFO = { - service: extract_api_info_for_service(service, _PATH_PREFIX) - for service in [MlflowService, DatabricksMlflowArtifactsService] - } def _extract_run_id(self): artifact_path = extract_and_normalize_path(self.artifact_uri) return artifact_path.split('/')[3] def _call_endpoint(self, service, api, json_body): - endpoint, method = self._SERVICE_AND_METHOD_TO_INFO[service][api] + endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api] response_proto = api.Response() return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto) - def _create_json_body(self, run_id, path=None): - path = path or "" - return { - "run_id": run_id, - "path": path - } - def _get_write_credentials(self, run_id, path=None): json_body = message_to_json(GetCredentialsForWrite(run_id=run_id, path=path)) return self._call_endpoint(DatabricksMlflowArtifactsService, @@ -67,23 +62,32 @@ def _get_read_credentials(self, run_id, path=None): GetCredentialsForRead, json_body) def _azure_upload_file(self, credentials, local_file, artifact_path): + """ + Uploads a file to a given Azure storage location. + + The function uses a file chunking generator, with 100 MB being the size limit for each chunk. + This limit is imposed by the stage_block API in azure-storage-blob. + In the case the file size is large and the upload takes longer than the validity of the given credentials, + a new credential is generated and the operation continues. + + Finally, a set of credentials is generated before the commit, since the prevailing credentials could + expire in the time between the last stage_block and the actually commit. 
+ """ + service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None) try: - if os.path.getsize(local_file) < _AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE: - signed_write_uri = credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) - with open(local_file, "rb") as data: - service.upload_blob(data, overwrite=True) - else: - uploading_block_list = list() - for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE): - signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) - block_id = base64.b64encode(uuid.uuid4().hex.encode()) + uploading_block_list = list() + for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE): + block_id = base64.b64encode(uuid.uuid4().hex.encode()) + try: + service.stage_block(block_id, chunk) + except ClientAuthenticationError: + new_credential = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=new_credential, credential=None) service.stage_block(block_id, chunk) - uploading_block_list.append(block_id) - signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) - service.commit_block_list(uploading_block_list) + uploading_block_list.append(block_id) + signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) + service.commit_block_list(uploading_block_list) except Exception as err: raise MlflowException(err) @@ -104,13 +108,13 @@ def _aws_download_file(self, credentials, local_path): pass def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path): - if cloud_credentials.credentials.type == 1: + if 
cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI: self._azure_upload_file(cloud_credentials.credentials, local_file, artifact_path) else: raise MlflowException('Not implemented yet') def _download_from_cloud(self, cloud_credentials, local_path): - if cloud_credentials.credentials.type == 1: + if cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI: self._azure_download_file(cloud_credentials.credentials, local_path) else: raise MlflowException('Not implemented yet') diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index f04a245ee7dea..e94c371b269e6 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -75,7 +75,6 @@ def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, write_credentials_response_proto = GetCredentialsForWrite.Response( credentials=mock_credentials) write_credentials_mock.return_value = write_credentials_response_proto - # get_run_id_mock.return_value = MOCK_RUN_ID azure_upload_mock.return_value = None databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) From ecdce6a8cf24f5ba2589ee59b79e09ea7b089d32 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 27 May 2020 16:50:36 -0700 Subject: [PATCH 08/28] Addressing comments, making azure download more memory efficient and other fixes. 
--- mlflow/store/artifact/artifact_repo.py | 2 +- .../artifact/databricks_artifact_repo.py | 69 ++++++++---- mlflow/store/artifact/dbfs_artifact_repo.py | 2 +- .../artifact/test_databricks_artifact_repo.py | 101 ++++++++++++++---- tests/utils/test_uri.py | 39 ++++--- 5 files changed, 156 insertions(+), 57 deletions(-) diff --git a/mlflow/store/artifact/artifact_repo.py b/mlflow/store/artifact/artifact_repo.py index ce43cb885d9d0..0105c86eb5649 100644 --- a/mlflow/store/artifact/artifact_repo.py +++ b/mlflow/store/artifact/artifact_repo.py @@ -81,12 +81,12 @@ def download_artifacts(self, artifact_path, dst_path=None): # TODO: Probably need to add a more efficient method to stream just a single artifact # without downloading it, or to get a pre-signed URL for cloud storage. def download_file(fullpath): + fullpath = fullpath.rstrip('/') dirpath, _ = posixpath.split(fullpath) local_dir_path = os.path.join(dst_path, dirpath) local_file_path = os.path.join(dst_path, fullpath) if not os.path.exists(local_dir_path): os.makedirs(local_dir_path) - print (fullpath) self._download_file(remote_file_path=fullpath, local_path=local_file_path) return local_file_path diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 2096f6a557443..c507ac0cbf921 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -4,10 +4,13 @@ import os import uuid import base64 +import logging +import requests from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repo import ArtifactRepository -from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, GetCredentialsForWrite, \ +from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \ + GetCredentialsForWrite, \ GetCredentialsForRead, ArtifactCredentialType from mlflow.protos.service_pb2 import MlflowService, ListArtifacts from 
mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri @@ -16,8 +19,9 @@ from mlflow.utils.rest_utils import call_endpoint, extract_api_info_for_service from mlflow.utils.databricks_utils import get_databricks_host_creds +_logger = logging.getLogger(__name__) _PATH_PREFIX = "/api/2.0" -_AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Maximum size of each block allowed is 100 MB in stage_block +_AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Max. size of each block allowed is 100 MB in stage_block _SERVICE_AND_METHOD_TO_INFO = { service: extract_api_info_for_service(service, _PATH_PREFIX) for service in [MlflowService, DatabricksMlflowArtifactsService] @@ -42,6 +46,17 @@ def __init__(self, artifact_uri): self.run_id = self._extract_run_id() def _extract_run_id(self): + """ + The artifact_uri is expected to be + dbfs:/databricks/mlflow-tracking///artifacts/ + Once the path from the inputted uri is extracted and normalized, it is + expected to be of the form + databricks/mlflow-tracking///artifacts/ + + Hence the run_id is the 4th element of the normalized path. + + :return: run_id extracted from the artifact_uri + """ artifact_path = extract_and_normalize_path(self.artifact_uri) return artifact_path.split('/')[3] @@ -65,39 +80,54 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): """ Uploads a file to a given Azure storage location. - The function uses a file chunking generator, with 100 MB being the size limit for each chunk. + The function uses a file chunking generator with 100 MB being the size limit for each chunk. This limit is imposed by the stage_block API in azure-storage-blob. - In the case the file size is large and the upload takes longer than the validity of the given credentials, - a new credential is generated and the operation continues. 
+ In the case the file size is large and the upload takes longer than the validity of the + given credentials, a new set of credentials are generated and the operation continues. This + is the reason for the first nested try-except block - Finally, a set of credentials is generated before the commit, since the prevailing credentials could - expire in the time between the last stage_block and the actually commit. + Finally, since the prevailing credentials could expire in the time between the last + stage_block and the commit, a second try-except block refreshes credentials if needed. """ - service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None) try: + service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None) uploading_block_list = list() for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE): block_id = base64.b64encode(uuid.uuid4().hex.encode()) try: service.stage_block(block_id, chunk) except ClientAuthenticationError: - new_credential = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=new_credential, credential=None) + _logger.warning( + "Failed to authorize request, possibly due to credential expiration." + "Refreshing credentials and trying again..") + credentials = self._get_write_credentials(self.run_id, + artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=credentials, credential=None) service.stage_block(block_id, chunk) uploading_block_list.append(block_id) - signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None) - service.commit_block_list(uploading_block_list) + try: + service.commit_block_list(uploading_block_list) + except ClientAuthenticationError: + _logger.warning( + "Failed to authorize request, possibly due to credential expiration." 
+ "Refreshing credentials and trying again..") + credentials = self._get_write_credentials(self.run_id, + artifact_path).credentials.signed_uri + service = BlobClient.from_blob_url(blob_url=credentials, credential=None) + service.commit_block_list(uploading_block_list) except Exception as err: raise MlflowException(err) def _azure_download_file(self, credentials, local_path): - signed_read_uri = credentials.signed_uri - service = BlobClient.from_blob_url(blob_url=signed_read_uri, credential=None) try: + signed_read_uri = credentials.signed_uri + response = requests.get(signed_read_uri) + response.raise_for_status() with open(local_path, "wb") as output_file: - blob = service.download_blob() - output_file.write(blob.readall()) + for chunk in response.iter_content(_AZURE_MAX_BLOCK_CHUNK_SIZE): + if not chunk: + break + output_file.write(chunk) except Exception as err: raise MlflowException(err) @@ -127,7 +157,7 @@ def log_artifact(self, local_file, artifact_path=None): self._upload_to_cloud(write_credentials, local_file, artifact_path) def log_artifacts(self, local_dir, artifact_path=None): - artifact_path = artifact_path or '' + artifact_path = artifact_path or "" for (dirpath, _, filenames) in os.walk(local_dir): artifact_subdir = artifact_path if dirpath != local_dir: @@ -144,7 +174,8 @@ def list_artifacts(self, path=None): # If `path` is a file, ListArtifacts returns a single list element with the # same name as `path`. The list_artifacts API expects us to return an empty list in this # case, so we do so here. 
- if len(artifact_list) == 1 and artifact_list[0].path == path: + if len(artifact_list) == 1 and artifact_list[0].path == path \ + and not artifact_list[0].is_dir: return [] return artifact_list diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 9b3041705d1e8..e56f615043328 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -163,7 +163,7 @@ def dbfs_artifact_repo_factory(artifact_uri): storage can only be used together with the RestStore. In the special case where the URI is of the form - ``dbfs:/databricks/mlflow-tracking///, + `dbfs:/databricks/mlflow-tracking///', a DatabricksArtifactRepository is returned. This is capable of storing access controlled artifacts. diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index e94c371b269e6..556cd00058f1d 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -4,6 +4,7 @@ import pytest import mock from unittest.mock import ANY +from azure.storage.blob import BlobClient from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository @@ -19,36 +20,31 @@ def databricks_artifact_repo(): DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo' -DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository" - -TEST_FILE_1_CONTENT = u"Hello 🍆🍔".encode("utf-8") -TEST_FILE_2_CONTENT = u"World 🍆🍔🍆".encode("utf-8") -TEST_FILE_3_CONTENT = u"¡🍆🍆🍔🍆🍆!".encode("utf-8") +DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + \ + ".DatabricksArtifactRepository" @pytest.fixture() def test_file(tmpdir): + test_file_content = u"Hello 🍆🍔".encode("utf-8") p = tmpdir.join("test.txt") with open(p.strpath, 'wb') as f: - 
f.write(TEST_FILE_1_CONTENT) + f.write(test_file_content) return p @pytest.fixture() def test_dir(tmpdir): + test_file_content = u"World 🍆🍔🍆".encode("utf-8") with open(tmpdir.mkdir('subdir').join('test.txt').strpath, 'wb') as f: - f.write(TEST_FILE_2_CONTENT) + f.write(test_file_content) with open(tmpdir.join('test.txt').strpath, 'wb') as f: - f.write(bytes(TEST_FILE_3_CONTENT)) + f.write(bytes(test_file_content)) with open(tmpdir.join('empty-file').strpath, 'w'): pass return tmpdir -LIST_ARTIFACTS_PROTO_RESPONSE = [FileInfo(path='test/a.txt', is_dir=False, file_size=100), - FileInfo(path='test/dir', is_dir=True, file_size=0)] - -LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE = [FileInfo(path='a.txt', is_dir=False, file_size=0)] MOCK_AZURE_SIGNED_URI = "this_is_a_mock_sas_for_azure" MOCK_RUN_ID = 'MOCK-RUN' @@ -57,19 +53,36 @@ class TestDatabricksArtifactRepository(object): def test_init_validation_and_cleaning(self): repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact') assert repo.artifact_uri == 'dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact' + assert repo.run_id == 'RUN' with pytest.raises(MlflowException): DatabricksArtifactRepository('s3://test') with pytest.raises(MlflowException): DatabricksArtifactRepository('dbfs:/databricks/mlflow/EXP/RUN/artifact') + def test_extract_run_id(self): + expected_run_id = "RUN_ID" + repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN_ID/artifact') + assert repo.run_id == expected_run_id + repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert repo.run_id == expected_run_id + repo = get_artifact_repository( + 'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') + assert repo.run_id == expected_run_id + repo = get_artifact_repository( + 'dbfs:/databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert repo.run_id == expected_run_id + @pytest.mark.parametrize("artifact_path,expected_location", [ (None, 
'test.txt'), ('output', 'output/test.txt'), ('', 'test.txt'), ]) - def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, expected_location): - with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') as write_credentials_mock, \ - mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') as azure_upload_mock: + def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, + expected_location): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') \ + as azure_upload_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) write_credentials_response_proto = GetCredentialsForWrite.Response( @@ -78,7 +91,22 @@ def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, azure_upload_mock.return_value = None databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) - azure_upload_mock.assert_called_with(mock_credentials, test_file.strpath) + azure_upload_mock.assert_called_with(mock_credentials, test_file.strpath, + expected_location) + + def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ): + mock_blob_service = mock.MagicMock(autospec=BlobClient) + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + mock_blob_service.from_blob_url().return_value = MlflowException("MOCK ERROR") + with pytest.raises(MlflowException): + 
databricks_artifact_repo.log_artifact(test_file.strpath) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) @pytest.mark.parametrize("artifact_path", [ None, @@ -90,15 +118,21 @@ def test_log_artifacts(self, databricks_artifact_repo, test_dir, artifact_path): log_artifact_mock.return_value = None databricks_artifact_repo.log_artifacts(test_dir.strpath, artifact_path) artifact_path = artifact_path or '' - calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), os.path.join(artifact_path, '')), - mock.call(os.path.join(test_dir.strpath, 'test.txt'), os.path.join(artifact_path, '')), + calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), + os.path.join(artifact_path, '')), + mock.call(os.path.join(test_dir.strpath, 'test.txt'), + os.path.join(artifact_path, '')), mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), os.path.join(artifact_path, 'subdir'))] log_artifact_mock.assert_has_calls(calls) def test_list_artifacts(self, databricks_artifact_repo): + list_artifact_file_proto_mock = [FileInfo(path='a.txt', is_dir=False, file_size=0)] + list_artifacts_dir_proto_mock = [FileInfo(path='test/a.txt', is_dir=False, file_size=100), + FileInfo(path='test/dir', is_dir=True, file_size=0)] with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: - list_artifact_response_proto = ListArtifacts.Response(root_uri='', files=LIST_ARTIFACTS_PROTO_RESPONSE) + list_artifact_response_proto = \ + ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) call_endpoint_mock.return_value = list_artifact_response_proto artifacts = databricks_artifact_repo.list_artifacts('test/') assert len(artifacts) == 2 @@ -110,8 +144,9 @@ def test_list_artifacts(self, databricks_artifact_repo): assert artifacts[1].file_size is 0 # Calling list_artifacts() on a path that's a file should return an empty list - list_artifact_response_proto = ListArtifacts.Response(root_uri='', - 
files=LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE) + list_artifact_response_proto = \ + ListArtifacts.Response(root_uri='', + files=list_artifact_file_proto_mock) call_endpoint_mock.return_value = list_artifact_response_proto artifacts = databricks_artifact_repo.list_artifacts('a.txt') assert len(artifacts) == 0 @@ -122,9 +157,12 @@ def test_list_artifacts(self, databricks_artifact_repo): ('output/test_file', None), ]) def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path, local_path): - with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') as read_credentials_mock, \ + with mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ + as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ - mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_download_file') as azure_download_mock: + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_download_file') \ + as azure_download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) read_credentials_response_proto = GetCredentialsForRead.Response( @@ -135,3 +173,20 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa databricks_artifact_repo.download_artifacts(remote_file_path, local_path) read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path) azure_download_mock.assert_called_with(mock_credentials, ANY) + + def test_databricks_download_file_fail_case(self, databricks_artifact_repo, test_file): + with mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ + as read_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ + mock.patch('requests.get') as request_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + read_credentials_response_proto 
= GetCredentialsForRead.Response( + credentials=mock_credentials) + read_credentials_mock.return_value = read_credentials_response_proto + get_list_mock.return_value = [] + request_mock.return_value = MlflowException("MOCK ERROR") + with pytest.raises(MlflowException): + databricks_artifact_repo.download_artifacts(test_file.strpath) + read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath) diff --git a/tests/utils/test_uri.py b/tests/utils/test_uri.py index 90ac58afc99b7..7464c21ef9fe7 100644 --- a/tests/utils/test_uri.py +++ b/tests/utils/test_uri.py @@ -196,19 +196,32 @@ def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments(): def test_extract_and_normalize_path(): base_uri = 'databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts' - assert extract_and_normalize_path('dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri - assert extract_and_normalize_path('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri - assert extract_and_normalize_path('dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri - assert extract_and_normalize_path('dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') == base_uri - assert extract_and_normalize_path('dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri - assert extract_and_normalize_path('dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri + assert extract_and_normalize_path( + 'dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert extract_and_normalize_path( + 'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert extract_and_normalize_path( + 'dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') == base_uri + assert extract_and_normalize_path( + 'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') == base_uri + assert extract_and_normalize_path( + 
'dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri + assert extract_and_normalize_path( + 'dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') == base_uri def test_is_databricks_acled_artifacts_uri(): - assert is_databricks_acled_artifacts_uri('dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') - assert is_databricks_acled_artifacts_uri('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') - assert is_databricks_acled_artifacts_uri('dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') - assert is_databricks_acled_artifacts_uri('dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') - assert is_databricks_acled_artifacts_uri('dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') - assert is_databricks_acled_artifacts_uri('dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') - assert not is_databricks_acled_artifacts_uri('dbfs:/databricks/mlflow//EXP_ID//RUN_ID///artifacts//') + assert is_databricks_acled_artifacts_uri( + 'dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri( + 'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri( + 'dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert is_databricks_acled_artifacts_uri( + 'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') + assert is_databricks_acled_artifacts_uri( + 'dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert is_databricks_acled_artifacts_uri( + 'dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert not is_databricks_acled_artifacts_uri( + 'dbfs:/databricks/mlflow//EXP_ID//RUN_ID///artifacts//') From 116c5d0a826972fec46d9e13d23cb9d9ef96c541 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 27 May 2020 16:56:11 -0700 Subject: [PATCH 09/28] Small fix --- .../store/artifact/databricks_artifact_repo.py | 16 ++++++++-------- 1 
file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index c507ac0cbf921..356d23db1d175 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -118,16 +118,16 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): except Exception as err: raise MlflowException(err) - def _azure_download_file(self, credentials, local_path): + def _azure_download_file(self, credentials, local_file): try: signed_read_uri = credentials.signed_uri - response = requests.get(signed_read_uri) - response.raise_for_status() - with open(local_path, "wb") as output_file: - for chunk in response.iter_content(_AZURE_MAX_BLOCK_CHUNK_SIZE): - if not chunk: - break - output_file.write(chunk) + with requests.get(signed_read_uri, stream=True) as response: + response.raise_for_status() + with open(local_file, "wb") as output_file: + for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE): + if not chunk: + break + output_file.write(chunk) except Exception as err: raise MlflowException(err) From 11be341cd800ae12928cb46bc6aecdb3a3eb01f3 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 27 May 2020 23:45:05 -0700 Subject: [PATCH 10/28] Fixing list_artifacts --- mlflow/store/artifact/databricks_artifact_repo.py | 7 ++++++- tests/store/artifact/test_databricks_artifact_repo.py | 7 ++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 356d23db1d175..4d09524a8e099 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -7,6 +7,7 @@ import logging import requests +from mlflow.entities import FileInfo from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repo import ArtifactRepository 
from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \ @@ -177,7 +178,11 @@ def list_artifacts(self, path=None): if len(artifact_list) == 1 and artifact_list[0].path == path \ and not artifact_list[0].is_dir: return [] - return artifact_list + infos = list() + for file in artifact_list: + artifact_size = None if file.is_dir else file.file_size + infos.append(FileInfo(file.path, file.is_dir, artifact_size)) + return infos def _download_file(self, remote_file_path, local_path): read_credentials = self._get_read_credentials(self.run_id, remote_file_path) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 556cd00058f1d..3730937fe15b4 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -82,7 +82,7 @@ def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ as write_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') \ - as azure_upload_mock: + as azure_upload_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) write_credentials_response_proto = GetCredentialsForWrite.Response( @@ -135,13 +135,14 @@ def test_list_artifacts(self, databricks_artifact_repo): ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) call_endpoint_mock.return_value = list_artifact_response_proto artifacts = databricks_artifact_repo.list_artifacts('test/') + print (artifacts) assert len(artifacts) == 2 assert artifacts[0].path == 'test/a.txt' assert artifacts[0].is_dir is False assert artifacts[0].file_size == 100 assert artifacts[1].path == 'test/dir' assert artifacts[1].is_dir is True - assert artifacts[1].file_size is 0 + assert artifacts[1].file_size is None # Calling list_artifacts() 
on a path that's a file should return an empty list list_artifact_response_proto = \ @@ -162,7 +163,7 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_download_file') \ - as azure_download_mock: + as azure_download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) read_credentials_response_proto = GetCredentialsForRead.Response( From 0609a5bb256081a88df1c233da963111aac6cd51 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Sun, 31 May 2020 18:56:29 -0700 Subject: [PATCH 11/28] Addressing final comments. --- mlflow/store/artifact/artifact_repo.py | 2 +- .../artifact/databricks_artifact_repo.py | 65 ++++++++++++------- mlflow/store/artifact/dbfs_artifact_repo.py | 5 +- .../artifact/test_databricks_artifact_repo.py | 4 +- 4 files changed, 48 insertions(+), 28 deletions(-) diff --git a/mlflow/store/artifact/artifact_repo.py b/mlflow/store/artifact/artifact_repo.py index 0105c86eb5649..70328a6de0c6e 100644 --- a/mlflow/store/artifact/artifact_repo.py +++ b/mlflow/store/artifact/artifact_repo.py @@ -81,7 +81,7 @@ def download_artifacts(self, artifact_path, dst_path=None): # TODO: Probably need to add a more efficient method to stream just a single artifact # without downloading it, or to get a pre-signed URL for cloud storage. 
def download_file(fullpath): - fullpath = fullpath.rstrip('/') + fullpath = fullpath.rstrip('/') # Prevents incorrect split if fullpath ends with a '/' dirpath, _ = posixpath.split(fullpath) local_dir_path = os.path.join(dst_path, dirpath) local_file_path = os.path.join(dst_path, fullpath) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 4d09524a8e099..7b4ade2b76811 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -6,13 +6,14 @@ import base64 import logging import requests +import posixpath from mlflow.entities import FileInfo from mlflow.exceptions import MlflowException from mlflow.store.artifact.artifact_repo import ArtifactRepository from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \ - GetCredentialsForWrite, \ - GetCredentialsForRead, ArtifactCredentialType + GetCredentialsForWrite, GetCredentialsForRead, ArtifactCredentialType +from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE from mlflow.protos.service_pb2 import MlflowService, ListArtifacts from mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri from mlflow.utils.proto_json_utils import message_to_json @@ -31,7 +32,11 @@ class DatabricksArtifactRepository(ArtifactRepository): """ - Stores artifacts on Azure/AWS with access control. + Performs storage operations on artifacts in the access-controlled + `dbfs:/databricks/mlflow-tracking` location. + + Signed access URIs for S3 / Azure Blob Storage are fetched from the MLflow service and used to + read and write files from/to this location. 
The artifact_uri is expected to be of the form dbfs:/databricks/mlflow-tracking///artifacts/ @@ -40,26 +45,13 @@ class DatabricksArtifactRepository(ArtifactRepository): def __init__(self, artifact_uri): super(DatabricksArtifactRepository, self).__init__(artifact_uri) if not artifact_uri.startswith('dbfs:/'): - raise MlflowException('DatabricksArtifactRepository URI must start with dbfs:/') + raise MlflowException(message='DatabricksArtifactRepository URI must start with dbfs:/', + error_code=INVALID_PARAMETER_VALUE) if not is_databricks_acled_artifacts_uri(artifact_uri): - raise MlflowException('Artifact URI incorrect. Expected path prefix to be ' - 'databricks/mlflow-tracking/path/to/artifact/..') - self.run_id = self._extract_run_id() - - def _extract_run_id(self): - """ - The artifact_uri is expected to be - dbfs:/databricks/mlflow-tracking///artifacts/ - Once the path from the inputted uri is extracted and normalized, is is - expected to be of the form - databricks/mlflow-tracking///artifacts/ - - Hence the run_id is the 4th element of the normalized path. - - :return: run_id extracted from the artifact_uri - """ - artifact_path = extract_and_normalize_path(self.artifact_uri) - return artifact_path.split('/')[3] + raise MlflowException(message=('Artifact URI incorrect. Expected path prefix to be' + ' databricks/mlflow-tracking/path/to/artifact/..'), + error_code=INVALID_PARAMETER_VALUE) + self.run_id = extract_run_id(self.artifact_uri) def _call_endpoint(self, service, api, json_body): endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api] @@ -120,6 +112,15 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): raise MlflowException(err) def _azure_download_file(self, credentials, local_file): + """ + Downloads a file from Azure storage and writes it to local_file. + + The default working of requests.get is to download the entire response body immediately. + However, this could be inefficient for large files. 
Hence the parameter `stream` is set to + true. This only downloads the response headers at first and keeps the connection open, + allowing content retrieval to be made via `iter_content`. + In addition, since the connection is kept open, refreshing credentials is not required. + """ try: signed_read_uri = credentials.signed_uri with requests.get(signed_read_uri, stream=True) as response: @@ -153,7 +154,7 @@ def _download_from_cloud(self, cloud_credentials, local_path): def log_artifact(self, local_file, artifact_path=None): basename = os.path.basename(local_file) artifact_path = artifact_path or "" - artifact_path = os.path.join(artifact_path, basename) + artifact_path = posixpath.join(artifact_path, basename) write_credentials = self._get_write_credentials(self.run_id, artifact_path) self._upload_to_cloud(write_credentials, local_file, artifact_path) @@ -164,7 +165,7 @@ def log_artifacts(self, local_dir, artifact_path=None): if dirpath != local_dir: rel_path = os.path.relpath(dirpath, local_dir) rel_path = relative_path_to_artifact_path(rel_path) - artifact_subdir = os.path.join(artifact_path, rel_path) + artifact_subdir = posixpath.join(artifact_path, rel_path) for name in filenames: file_path = os.path.join(dirpath, name) self.log_artifact(file_path, artifact_subdir) @@ -190,3 +191,19 @@ def _download_file(self, remote_file_path, local_path): def delete_artifacts(self, artifact_path=None): raise MlflowException('Not implemented yet') + + +def extract_run_id(artifact_uri): + """ + The artifact_uri is expected to be + dbfs:/databricks/mlflow-tracking///artifacts/ + Once the path from the input uri is extracted and normalized, it is + expected to be of the form + databricks/mlflow-tracking///artifacts/ + + Hence the run_id is the 4th element of the normalized path. 
+ + :return: run_id extracted from the artifact_uri + """ + artifact_path = extract_and_normalize_path(artifact_uri) + return artifact_path.split('/')[3] diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index e56f615043328..92c0a8b64c602 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -28,6 +28,7 @@ class DbfsRestArtifactRepository(ArtifactRepository): This repository is used with URIs of the form ``dbfs:/``. The repository can only be used together with the RestStore. """ + def __init__(self, artifact_uri): super(DbfsRestArtifactRepository, self).__init__(artifact_uri) # NOTE: if we ever need to support databricks profiles different from that set for @@ -173,8 +174,8 @@ def dbfs_artifact_repo_factory(artifact_uri): cleaned_artifact_uri = artifact_uri.rstrip('/') uri_scheme = get_uri_scheme(artifact_uri) if uri_scheme != 'dbfs': - raise Exception("DBFS URI must be of the form " - "dbfs:/, but received {uri}".format(uri=artifact_uri)) + raise MlflowException("DBFS URI must be of the form " + "dbfs:/, but received {uri}".format(uri=artifact_uri)) if is_databricks_acled_artifacts_uri(artifact_uri): return DatabricksArtifactRepository(artifact_uri) elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 3730937fe15b4..cac90199d1cdd 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -12,6 +12,7 @@ from mlflow.protos.service_pb2 import ListArtifacts, FileInfo from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \ ArtifactCredentialType, ArtifactCredentialInfo +from mlflow.entities.file_info import FileInfo as FileInfoEntity @pytest.fixture() @@ -135,7 +136,8 @@ def test_list_artifacts(self, 
databricks_artifact_repo): ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) call_endpoint_mock.return_value = list_artifact_response_proto artifacts = databricks_artifact_repo.list_artifacts('test/') - print (artifacts) + assert isinstance(artifacts, list) + assert isinstance(artifacts[0], FileInfoEntity) assert len(artifacts) == 2 assert artifacts[0].path == 'test/a.txt' assert artifacts[0].is_dir is False From 3f1327d0ca3c70e7799ffd8d74bc077758bf1b32 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Sun, 31 May 2020 19:16:07 -0700 Subject: [PATCH 12/28] Making extract_run_id static --- .../artifact/databricks_artifact_repo.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 7b4ade2b76811..0533e7291b450 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -51,7 +51,23 @@ def __init__(self, artifact_uri): raise MlflowException(message=('Artifact URI incorrect. Expected path prefix to be' ' databricks/mlflow-tracking/path/to/artifact/..'), error_code=INVALID_PARAMETER_VALUE) - self.run_id = extract_run_id(self.artifact_uri) + self.run_id = self._extract_run_id(self.artifact_uri) + + @staticmethod + def _extract_run_id(artifact_uri): + """ + The artifact_uri is expected to be + dbfs:/databricks/mlflow-tracking///artifacts/ + Once the path from the input uri is extracted and normalized, it is + expected to be of the form + databricks/mlflow-tracking///artifacts/ + + Hence the run_id is the 4th element of the normalized path. 
+ + :return: run_id extracted from the artifact_uri + """ + artifact_path = extract_and_normalize_path(artifact_uri) + return artifact_path.split('/')[3] def _call_endpoint(self, service, api, json_body): endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api] @@ -191,19 +207,3 @@ def _download_file(self, remote_file_path, local_path): def delete_artifacts(self, artifact_path=None): raise MlflowException('Not implemented yet') - - -def extract_run_id(artifact_uri): - """ - The artifact_uri is expected to be - dbfs:/databricks/mlflow-tracking///artifacts/ - Once the path from the input uri is extracted and normalized, it is - expected to be of the form - databricks/mlflow-tracking///artifacts/ - - Hence the run_id is the 4th element of the normalized path. - - :return: run_id extracted from the artifact_uri - """ - artifact_path = extract_and_normalize_path(artifact_uri) - return artifact_path.split('/')[3] From 98035b4975daec0b714ee991d92af9905553802c Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Tue, 2 Jun 2020 14:58:23 -0700 Subject: [PATCH 13/28] Adding AWS support --- mlflow/protos/databricks_artifacts.proto | 26 ++++- mlflow/protos/databricks_artifacts_pb2.py | 106 ++++++++++++++---- .../artifact/databricks_artifact_repo.py | 72 +++++++----- .../artifact/test_databricks_artifact_repo.py | 38 +++++-- 4 files changed, 176 insertions(+), 66 deletions(-) diff --git a/mlflow/protos/databricks_artifacts.proto b/mlflow/protos/databricks_artifacts.proto index 6e516743b2fd6..65e194970327e 100644 --- a/mlflow/protos/databricks_artifacts.proto +++ b/mlflow/protos/databricks_artifacts.proto @@ -70,21 +70,36 @@ message ArtifactCredentialInfo { // The signed URI credential that provides access to the artifact optional string signed_uri = 3; + message HttpHeader { + + // The HTTP header name + optional string name = 1; + + // The HTTP header value + optional string value = 2; + + } + + // A collection of HTTP headers that should be specified when uploading to + // or 
downloading from the specified `signed_uri` + repeated HttpHeader headers = 4; + // The type of the signed credential URI (e.g., an AWS presigned URL // or an Azure Shared Access Signature URI) - optional ArtifactCredentialType type = 4; + optional ArtifactCredentialType type = 5; } message GetCredentialsForRead { option (scalapb.message).extends = "com.databricks.rpc.RPC[$this.Response]"; + option (scalapb.message).extends = "com.databricks.mlflow.api.MlflowTrackingMessage"; // The ID of the MLflow Run for which to fetch artifact read credentials - optional string run_id = 1; + optional string run_id = 1 [(validate_required) = true]; // The artifact path, relative to the Run's artifact root location, for which to // fetch artifact read credentials - optional string path = 2; + optional string path = 2 [(validate_required) = true]; message Response { @@ -96,13 +111,14 @@ message GetCredentialsForRead { message GetCredentialsForWrite { option (scalapb.message).extends = "com.databricks.rpc.RPC[$this.Response]"; + option (scalapb.message).extends = "com.databricks.mlflow.api.MlflowTrackingMessage"; // The ID of the MLflow Run for which to fetch artifact write credentials - optional string run_id = 1; + optional string run_id = 1 [(validate_required) = true]; // The artifact path, relative to the Run's artifact root location, for which to // fetch artifact write credentials - optional string path = 2; + optional string path = 2 [(validate_required) = true]; message Response { diff --git a/mlflow/protos/databricks_artifacts_pb2.py b/mlflow/protos/databricks_artifacts_pb2.py index 8d67294e1cf93..89467048dbac7 100644 --- a/mlflow/protos/databricks_artifacts_pb2.py +++ b/mlflow/protos/databricks_artifacts_pb2.py @@ -24,7 +24,7 @@ package='mlflow', syntax='proto2', serialized_options=_b('\n\037com.databricks.api.proto.mlflow\220\001\001\240\001\001\342?\002\020\001'), - 
serialized_pb=_b('\n\x1a\x64\x61tabricks_artifacts.proto\x12\x06mlflow\x1a\x15scalapb/scalapb.proto\x1a\x10\x64\x61tabricks.proto\"x\n\x16\x41rtifactCredentialInfo\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x12\x12\n\nsigned_uri\x18\x03 \x01(\t\x12,\n\x04type\x18\x04 \x01(\x0e\x32\x1e.mlflow.ArtifactCredentialType\"\xa3\x01\n\x15GetCredentialsForRead\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xa4\x01\n\x16GetCredentialsForWrite\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]*B\n\x16\x41rtifactCredentialType\x12\x11\n\rAZURE_SAS_URI\x10\x01\x12\x15\n\x11\x41WS_PRESIGNED_URL\x10\x02\x32\xe2\x02\n DatabricksMlflowArtifactsService\x12\x9b\x01\n\x15getCredentialsForRead\x12\x1d.mlflow.GetCredentialsForRead\x1a&.mlflow.GetCredentialsForRead.Response\";\xf2\x86\x19\x37\n3\n\x03GET\x12&/mlflow/artifacts/credentials-for-read\x1a\x04\x08\x02\x10\x00\x10\x03\x12\x9f\x01\n\x16getCredentialsForWrite\x12\x1e.mlflow.GetCredentialsForWrite\x1a\'.mlflow.GetCredentialsForWrite.Response\"<\xf2\x86\x19\x38\n4\n\x03GET\x12\'/mlflow/artifacts/credentials-for-write\x1a\x04\x08\x02\x10\x00\x10\x03\x42,\n\x1f\x63om.databricks.api.proto.mlflow\x90\x01\x01\xa0\x01\x01\xe2?\x02\x10\x01') + serialized_pb=_b('\n\x1a\x64\x61tabricks_artifacts.proto\x12\x06mlflow\x1a\x15scalapb/scalapb.proto\x1a\x10\x64\x61tabricks.proto\"\xdf\x01\n\x16\x41rtifactCredentialInfo\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x12\x12\n\nsigned_uri\x18\x03 \x01(\t\x12:\n\x07headers\x18\x04 \x03(\x0b\x32).mlflow.ArtifactCredentialInfo.HttpHeader\x12,\n\x04type\x18\x05 
\x01(\x0e\x32\x1e.mlflow.ArtifactCredentialType\x1a)\n\nHttpHeader\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\xe3\x01\n\x15GetCredentialsForRead\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x12\n\x04path\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:_\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\xe2?1\n/com.databricks.mlflow.api.MlflowTrackingMessage\"\xe4\x01\n\x16GetCredentialsForWrite\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x12\n\x04path\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a?\n\x08Response\x12\x33\n\x0b\x63redentials\x18\x01 \x01(\x0b\x32\x1e.mlflow.ArtifactCredentialInfo:_\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\xe2?1\n/com.databricks.mlflow.api.MlflowTrackingMessage*B\n\x16\x41rtifactCredentialType\x12\x11\n\rAZURE_SAS_URI\x10\x01\x12\x15\n\x11\x41WS_PRESIGNED_URL\x10\x02\x32\xe2\x02\n DatabricksMlflowArtifactsService\x12\x9b\x01\n\x15getCredentialsForRead\x12\x1d.mlflow.GetCredentialsForRead\x1a&.mlflow.GetCredentialsForRead.Response\";\xf2\x86\x19\x37\n3\n\x03GET\x12&/mlflow/artifacts/credentials-for-read\x1a\x04\x08\x02\x10\x00\x10\x03\x12\x9f\x01\n\x16getCredentialsForWrite\x12\x1e.mlflow.GetCredentialsForWrite\x1a\'.mlflow.GetCredentialsForWrite.Response\"<\xf2\x86\x19\x38\n4\n\x03GET\x12\'/mlflow/artifacts/credentials-for-write\x1a\x04\x08\x02\x10\x00\x10\x03\x42,\n\x1f\x63om.databricks.api.proto.mlflow\x90\x01\x01\xa0\x01\x01\xe2?\x02\x10\x01') , dependencies=[scalapb_dot_scalapb__pb2.DESCRIPTOR,databricks__pb2.DESCRIPTOR,]) @@ -45,8 +45,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=534, - serialized_end=600, + serialized_start=766, + serialized_end=832, ) _sym_db.RegisterEnumDescriptor(_ARTIFACTCREDENTIALTYPE) @@ -56,6 +56,43 @@ +_ARTIFACTCREDENTIALINFO_HTTPHEADER = _descriptor.Descriptor( + name='HttpHeader', + 
full_name='mlflow.ArtifactCredentialInfo.HttpHeader', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='mlflow.ArtifactCredentialInfo.HttpHeader.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='mlflow.ArtifactCredentialInfo.HttpHeader.value', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=262, + serialized_end=303, +) + _ARTIFACTCREDENTIALINFO = _descriptor.Descriptor( name='ArtifactCredentialInfo', full_name='mlflow.ArtifactCredentialInfo', @@ -85,8 +122,15 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='type', full_name='mlflow.ArtifactCredentialInfo.type', index=3, - number=4, type=14, cpp_type=8, label=1, + name='headers', full_name='mlflow.ArtifactCredentialInfo.headers', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='mlflow.ArtifactCredentialInfo.type', index=4, + number=5, type=14, cpp_type=8, label=1, has_default_value=False, default_value=1, message_type=None, enum_type=None, 
containing_type=None, is_extension=False, extension_scope=None, @@ -94,7 +138,7 @@ ], extensions=[ ], - nested_types=[], + nested_types=[_ARTIFACTCREDENTIALINFO_HTTPHEADER, ], enum_types=[ ], serialized_options=None, @@ -103,8 +147,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=79, - serialized_end=199, + serialized_start=80, + serialized_end=303, ) @@ -134,8 +178,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=257, - serialized_end=320, + serialized_start=373, + serialized_end=436, ) _GETCREDENTIALSFORREAD = _descriptor.Descriptor( @@ -151,28 +195,28 @@ has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), + serialized_options=_b('\370\206\031\001'), file=DESCRIPTOR), _descriptor.FieldDescriptor( name='path', full_name='mlflow.GetCredentialsForRead.path', index=1, number=2, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), + serialized_options=_b('\370\206\031\001'), file=DESCRIPTOR), ], extensions=[ ], nested_types=[_GETCREDENTIALSFORREAD_RESPONSE, ], enum_types=[ ], - serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]'), + serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]\342?1\n/com.databricks.mlflow.api.MlflowTrackingMessage'), is_extendable=False, syntax='proto2', extension_ranges=[], oneofs=[ ], - serialized_start=202, - serialized_end=365, + serialized_start=306, + serialized_end=533, ) @@ -202,8 +246,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=257, - serialized_end=320, + serialized_start=373, + serialized_end=436, ) _GETCREDENTIALSFORWRITE = _descriptor.Descriptor( @@ -219,30 +263,32 @@ has_default_value=False, default_value=_b("").decode('utf-8'), 
message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), + serialized_options=_b('\370\206\031\001'), file=DESCRIPTOR), _descriptor.FieldDescriptor( name='path', full_name='mlflow.GetCredentialsForWrite.path', index=1, number=2, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), + serialized_options=_b('\370\206\031\001'), file=DESCRIPTOR), ], extensions=[ ], nested_types=[_GETCREDENTIALSFORWRITE_RESPONSE, ], enum_types=[ ], - serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]'), + serialized_options=_b('\342?(\n&com.databricks.rpc.RPC[$this.Response]\342?1\n/com.databricks.mlflow.api.MlflowTrackingMessage'), is_extendable=False, syntax='proto2', extension_ranges=[], oneofs=[ ], - serialized_start=368, - serialized_end=532, + serialized_start=536, + serialized_end=764, ) +_ARTIFACTCREDENTIALINFO_HTTPHEADER.containing_type = _ARTIFACTCREDENTIALINFO +_ARTIFACTCREDENTIALINFO.fields_by_name['headers'].message_type = _ARTIFACTCREDENTIALINFO_HTTPHEADER _ARTIFACTCREDENTIALINFO.fields_by_name['type'].enum_type = _ARTIFACTCREDENTIALTYPE _GETCREDENTIALSFORREAD_RESPONSE.fields_by_name['credentials'].message_type = _ARTIFACTCREDENTIALINFO _GETCREDENTIALSFORREAD_RESPONSE.containing_type = _GETCREDENTIALSFORREAD @@ -255,11 +301,19 @@ _sym_db.RegisterFileDescriptor(DESCRIPTOR) ArtifactCredentialInfo = _reflection.GeneratedProtocolMessageType('ArtifactCredentialInfo', (_message.Message,), dict( + + HttpHeader = _reflection.GeneratedProtocolMessageType('HttpHeader', (_message.Message,), dict( + DESCRIPTOR = _ARTIFACTCREDENTIALINFO_HTTPHEADER, + __module__ = 'databricks_artifacts_pb2' + # @@protoc_insertion_point(class_scope:mlflow.ArtifactCredentialInfo.HttpHeader) + )) + , DESCRIPTOR = 
_ARTIFACTCREDENTIALINFO, __module__ = 'databricks_artifacts_pb2' # @@protoc_insertion_point(class_scope:mlflow.ArtifactCredentialInfo) )) _sym_db.RegisterMessage(ArtifactCredentialInfo) +_sym_db.RegisterMessage(ArtifactCredentialInfo.HttpHeader) GetCredentialsForRead = _reflection.GeneratedProtocolMessageType('GetCredentialsForRead', (_message.Message,), dict( @@ -293,7 +347,11 @@ DESCRIPTOR._options = None +_GETCREDENTIALSFORREAD.fields_by_name['run_id']._options = None +_GETCREDENTIALSFORREAD.fields_by_name['path']._options = None _GETCREDENTIALSFORREAD._options = None +_GETCREDENTIALSFORWRITE.fields_by_name['run_id']._options = None +_GETCREDENTIALSFORWRITE.fields_by_name['path']._options = None _GETCREDENTIALSFORWRITE._options = None _DATABRICKSMLFLOWARTIFACTSSERVICE = _descriptor.ServiceDescriptor( @@ -302,8 +360,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=603, - serialized_end=957, + serialized_start=835, + serialized_end=1189, methods=[ _descriptor.MethodDescriptor( name='getCredentialsForRead', diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 0533e7291b450..05649a47e4841 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -24,6 +24,7 @@ _logger = logging.getLogger(__name__) _PATH_PREFIX = "/api/2.0" _AZURE_MAX_BLOCK_CHUNK_SIZE = 100000000 # Max. 
size of each block allowed is 100 MB in stage_block +_DOWNLOAD_CHUNK_SIZE = 100000000 _SERVICE_AND_METHOD_TO_INFO = { service: extract_api_info_for_service(service, _PATH_PREFIX) for service in [MlflowService, DatabricksMlflowArtifactsService] @@ -85,6 +86,12 @@ def _get_read_credentials(self, run_id, path=None): return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body) + def _extract_headers_from_credentials(self, credential): + headers = dict() + for header in credential.headers: + headers[header.name] = header.value + return headers + def _azure_upload_file(self, credentials, local_file, artifact_path): """ Uploads a file to a given Azure storage location. @@ -99,12 +106,14 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): stage_block and the commit, a second try-except block refreshes credentials if needed. """ try: - service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None) + headers = self._extract_headers_from_credentials(credentials) + service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None, + headers=headers) uploading_block_list = list() for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE): block_id = base64.b64encode(uuid.uuid4().hex.encode()) try: - service.stage_block(block_id, chunk) + service.stage_block(block_id, chunk, headers=headers) except ClientAuthenticationError: _logger.warning( "Failed to authorize request, possibly due to credential expiration." 
@@ -112,7 +121,7 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): credentials = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri service = BlobClient.from_blob_url(blob_url=credentials, credential=None) - service.stage_block(block_id, chunk) + service.stage_block(block_id, chunk, headers=headers) uploading_block_list.append(block_id) try: service.commit_block_list(uploading_block_list) @@ -127,46 +136,53 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): except Exception as err: raise MlflowException(err) - def _azure_download_file(self, credentials, local_file): + def _aws_upload_file(self, credentials, local_file): + try: + headers = self._extract_headers_from_credentials(credentials) + signed_write_uri = credentials.signed_uri + with open(local_file, 'rb') as file: + put_request = requests.put(signed_write_uri, headers=headers, data=file) + put_request.raise_for_status() + except Exception as err: + raise MlflowException(err) + + def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path): + if cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI: + self._azure_upload_file(cloud_credentials.credentials, local_file, artifact_path) + elif cloud_credentials.credentials.type == ArtifactCredentialType.AWS_PRESIGNED_URL: + self._aws_upload_file(cloud_credentials.credentials, local_file) + else: + raise MlflowException('Not implemented yet') + + def _download_from_cloud(self, cloud_credential, local_file_path): """ - Downloads a file from Azure storage and writes it to local_file. + Downloads a file from the input `cloud_credential` and save it to `local_path`. + + Since the download mechanism for both cloud services, i.e., Azure and AWS is the same, + a single download method is sufficient. - The default working of requests.get is to download the entire response body immediately. 
+ The default working of `requests.get` is to download the entire response body immediately. However, this could be inefficient for large files. Hence the parameter `stream` is set to true. This only downloads the response headers at first and keeps the connection open, allowing content retrieval to be made via `iter_content`. In addition, since the connection is kept open, refreshing credentials is not required. """ + if cloud_credential.type not in [ArtifactCredentialType.AZURE_SAS_URI, + ArtifactCredentialType.AWS_PRESIGNED_URL]: + raise MlflowException(message='Cloud provider not supported.', + error_code=INVALID_PARAMETER_VALUE) try: - signed_read_uri = credentials.signed_uri + signed_read_uri = cloud_credential.signed_uri with requests.get(signed_read_uri, stream=True) as response: response.raise_for_status() - with open(local_file, "wb") as output_file: - for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE): + with open(local_file_path, "wb") as output_file: + for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE): if not chunk: break output_file.write(chunk) except Exception as err: raise MlflowException(err) - def _aws_upload_file(self, credentials, local_file): - pass - - def _aws_download_file(self, credentials, local_path): - pass - - def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path): - if cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI: - self._azure_upload_file(cloud_credentials.credentials, local_file, artifact_path) - else: - raise MlflowException('Not implemented yet') - - def _download_from_cloud(self, cloud_credentials, local_path): - if cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI: - self._azure_download_file(cloud_credentials.credentials, local_path) - else: - raise MlflowException('Not implemented yet') - def log_artifact(self, local_file, artifact_path=None): basename = os.path.basename(local_file) artifact_path = artifact_path 
or "" @@ -203,7 +219,7 @@ def list_artifacts(self, path=None): def _download_file(self, remote_file_path, local_path): read_credentials = self._get_read_credentials(self.run_id, remote_file_path) - self._download_from_cloud(read_credentials, local_path) + self._download_from_cloud(read_credentials.credentials, local_path) def delete_artifacts(self, artifact_path=None): raise MlflowException('Not implemented yet') diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index cac90199d1cdd..af7587501ca1f 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -17,7 +17,7 @@ @pytest.fixture() def databricks_artifact_repo(): - return get_artifact_repository('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN/artifact') + return get_artifact_repository('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifact') DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo' @@ -47,7 +47,8 @@ def test_dir(tmpdir): MOCK_AZURE_SIGNED_URI = "this_is_a_mock_sas_for_azure" -MOCK_RUN_ID = 'MOCK-RUN' +MOCK_AWS_SIGNED_URI = "this_is_a_mock_presigned_uri_for_aws" +MOCK_RUN_ID = 'MOCK-RUN-ID' class TestDatabricksArtifactRepository(object): @@ -78,8 +79,8 @@ def test_extract_run_id(self): ('output', 'output/test.txt'), ('', 'test.txt'), ]) - def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, - expected_location): + def test_log_artifact_azure(self, databricks_artifact_repo, test_file, artifact_path, + expected_location): with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ as write_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') \ @@ -95,6 +96,25 @@ def test_log_artifact(self, databricks_artifact_repo, test_file, artifact_path, azure_upload_mock.assert_called_with(mock_credentials, test_file.strpath, expected_location) + 
@pytest.mark.parametrize("artifact_path,expected_location", [ + (None, 'test.txt'), + ]) + def test_log_artifact_aws(self, databricks_artifact_repo, test_file, artifact_path, + expected_location): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._aws_upload_file') \ + as aws_upload_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AWS_SIGNED_URI, + type=ArtifactCredentialType.AWS_PRESIGNED_URL) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + aws_upload_mock.return_value = None + databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) + aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath) + def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ): mock_blob_service = mock.MagicMock(autospec=BlobClient) with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ @@ -164,20 +184,20 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ - mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_download_file') \ - as azure_download_mock: + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ + as download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) read_credentials_response_proto = GetCredentialsForRead.Response( credentials=mock_credentials) read_credentials_mock.return_value = read_credentials_response_proto - azure_download_mock.return_value = None + download_mock.return_value = 
None get_list_mock.return_value = [] databricks_artifact_repo.download_artifacts(remote_file_path, local_path) read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path) - azure_download_mock.assert_called_with(mock_credentials, ANY) + download_mock.assert_called_with(mock_credentials, ANY) - def test_databricks_download_file_fail_case(self, databricks_artifact_repo, test_file): + def test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file): with mock.patch( DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ as read_credentials_mock, \ From 0b1af46a04233f9f028074d67cb70027fa2d27e9 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 3 Jun 2020 17:43:12 -0700 Subject: [PATCH 14/28] Addressing comments --- .../api/proto/mlflow/DatabricksArtifacts.java | 1648 +++++++++++++++-- .../artifact/databricks_artifact_repo.py | 46 +- .../artifact/test_databricks_artifact_repo.py | 126 +- 3 files changed, 1633 insertions(+), 187 deletions(-) diff --git a/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java b/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java index e94afa768bd0b..a0beb4844cec0 100644 --- a/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java +++ b/mlflow/java/client/src/main/java/com/databricks/api/proto/mlflow/DatabricksArtifacts.java @@ -216,13 +216,62 @@ public interface ArtifactCredentialInfoOrBuilder extends com.google.protobuf.ByteString getSignedUriBytes(); + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + java.util.List + getHeadersList(); + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getHeaders(int index); + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + int getHeadersCount(); + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + java.util.List + getHeadersOrBuilderList(); + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder getHeadersOrBuilder( + int index); + /** *
      * The type of the signed credential URI (e.g., an AWS presigned URL
      * or an Azure Shared Access Signature URI)
      * 
* - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ boolean hasType(); /** @@ -231,7 +280,7 @@ public interface ArtifactCredentialInfoOrBuilder extends * or an Azure Shared Access Signature URI) * * - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType(); } @@ -251,6 +300,7 @@ private ArtifactCredentialInfo() { runId_ = ""; path_ = ""; signedUri_ = ""; + headers_ = java.util.Collections.emptyList(); type_ = 1; } @@ -296,48 +346,904 @@ private ArtifactCredentialInfo( signedUri_ = bs; break; } - case 32: { - int rawValue = input.readEnum(); - @SuppressWarnings("deprecation") - com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType value = com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.valueOf(rawValue); - if (value == null) { - unknownFields.mergeVarintField(4, rawValue); - } else { - bitField0_ |= 0x00000008; - type_ = rawValue; - } - break; + case 34: { + if (!((mutable_bitField0_ & 0x00000008) == 0x00000008)) { + headers_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000008; + } + headers_.add( + input.readMessage(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.PARSER, extensionRegistry)); + break; + } + case 40: { + int rawValue = input.readEnum(); + @SuppressWarnings("deprecation") + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType value = com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType.valueOf(rawValue); + if (value == null) { + unknownFields.mergeVarintField(5, rawValue); + } else { + bitField0_ |= 0x00000008; + type_ = rawValue; + } + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch 
(com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000008) == 0x00000008)) { + headers_ = java.util.Collections.unmodifiableList(headers_); + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder.class); + } + + public interface HttpHeaderOrBuilder extends + // @@protoc_insertion_point(interface_extends:mlflow.ArtifactCredentialInfo.HttpHeader) + com.google.protobuf.MessageOrBuilder { + + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + boolean hasName(); + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + java.lang.String getName(); + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + com.google.protobuf.ByteString + getNameBytes(); + + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + boolean hasValue(); + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + java.lang.String getValue(); + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + com.google.protobuf.ByteString + getValueBytes(); + } + /** + * Protobuf type {@code mlflow.ArtifactCredentialInfo.HttpHeader} + */ + public static final class HttpHeader extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:mlflow.ArtifactCredentialInfo.HttpHeader) + HttpHeaderOrBuilder { + private static final long serialVersionUID = 0L; + // Use HttpHeader.newBuilder() to construct. + private HttpHeader(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private HttpHeader() { + name_ = ""; + value_ = ""; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private HttpHeader( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000001; + name_ = bs; + break; + } + case 18: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000002; + value_ = bs; + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = 
unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder.class); + } + + private int bitField0_; + public static final int NAME_FIELD_NUMBER = 1; + private volatile java.lang.Object name_; + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + public boolean hasName() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + public java.lang.String getName() { + java.lang.Object ref = name_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + name_ = s; + } + return s; + } + } + /** + *
+       * The HTTP header name
+       * 
+ * + * optional string name = 1; + */ + public com.google.protobuf.ByteString + getNameBytes() { + java.lang.Object ref = name_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + name_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int VALUE_FIELD_NUMBER = 2; + private volatile java.lang.Object value_; + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + public boolean hasValue() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + public java.lang.String getValue() { + java.lang.Object ref = value_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + value_ = s; + } + return s; + } + } + /** + *
+       * The HTTP header value
+       * 
+ * + * optional string value = 2; + */ + public com.google.protobuf.ByteString + getValueBytes() { + java.lang.Object ref = value_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + value_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) == 0x00000001)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, name_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, value_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, name_); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, value_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader)) { + return super.equals(obj); + } + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader other = 
(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader) obj; + + boolean result = true; + result = result && (hasName() == other.hasName()); + if (hasName()) { + result = result && getName() + .equals(other.getName()); + } + result = result && (hasValue() == other.hasValue()); + if (hasValue()) { + result = result && getValue() + .equals(other.getValue()); + } + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasName()) { + hash = (37 * hash) + NAME_FIELD_NUMBER; + hash = (53 * hash) + getName().hashCode(); + } + if (hasValue()) { + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite 
extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static 
com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code mlflow.ArtifactCredentialInfo.HttpHeader} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:mlflow.ArtifactCredentialInfo.HttpHeader) + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_fieldAccessorTable + .ensureFieldAccessorsInitialized( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder.class); + } + + // Construct using com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + name_ = ""; + bitField0_ = (bitField0_ & ~0x00000001); + value_ = ""; + bitField0_ = (bitField0_ & ~0x00000002); + return this; + } 
+ + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getDefaultInstanceForType() { + return com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.getDefaultInstance(); + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader build() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader buildPartial() { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader result = new com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + result.name_ = name_; + if (((from_bitField0_ & 0x00000002) == 0x00000002)) { + to_bitField0_ |= 0x00000002; + } + result.value_ = value_; + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return (Builder) super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + 
@java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader) { + return mergeFrom((com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader other) { + if (other == com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.getDefaultInstance()) return this; + if (other.hasName()) { + bitField0_ |= 0x00000001; + name_ = other.name_; + onChanged(); + } + if (other.hasValue()) { + bitField0_ |= 0x00000002; + value_ = other.value_; + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch 
(com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private java.lang.Object name_ = ""; + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public boolean hasName() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public java.lang.String getName() { + java.lang.Object ref = name_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + name_ = s; } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public com.google.protobuf.ByteString + getNameBytes() { + java.lang.Object ref = name_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + name_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public Builder setName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + name_ = value; + onChanged(); + return this; + } + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public Builder clearName() { + bitField0_ = (bitField0_ & ~0x00000001); + name_ = getDefaultInstance().getName(); + onChanged(); + return this; + } + /** + *
+         * The HTTP header name
+         * 
+ * + * optional string name = 1; + */ + public Builder setNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + name_ = value; + onChanged(); + return this; + } + + private java.lang.Object value_ = ""; + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public boolean hasValue() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public java.lang.String getValue() { + java.lang.Object ref = value_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + value_ = s; } + return s; + } else { + return (java.lang.String) ref; } } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public com.google.protobuf.ByteString + getValueBytes() { + java.lang.Object ref = value_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + value_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public Builder setValue( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + value_ = value; + onChanged(); + return this; + } + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public Builder clearValue() { + bitField0_ = (bitField0_ & ~0x00000002); + value_ = getDefaultInstance().getValue(); + onChanged(); + return this; + } + /** + *
+         * The HTTP header value
+         * 
+ * + * optional string value = 2; + */ + public Builder setValueBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000002; + value_ = value; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:mlflow.ArtifactCredentialInfo.HttpHeader) + } + + // @@protoc_insertion_point(class_scope:mlflow.ArtifactCredentialInfo.HttpHeader) + private static final com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader(); + } + + public static com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + @java.lang.Deprecated public static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public HttpHeader parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new HttpHeader(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getDefaultInstanceForType() { + return DEFAULT_INSTANCE; } - } - public 
static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_descriptor; - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return com.databricks.api.proto.mlflow.DatabricksArtifacts.internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable - .ensureFieldAccessorsInitialized( - com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.class, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.Builder.class); } private int bitField0_; @@ -490,26 +1396,86 @@ public java.lang.String getSignedUri() { } /** *
-     * The signed URI credential that provides access to the artifact
+     * The signed URI credential that provides access to the artifact
+     * 
+ * + * optional string signed_uri = 3; + */ + public com.google.protobuf.ByteString + getSignedUriBytes() { + java.lang.Object ref = signedUri_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + signedUri_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int HEADERS_FIELD_NUMBER = 4; + private java.util.List headers_; + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public java.util.List getHeadersList() { + return headers_; + } + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public java.util.List + getHeadersOrBuilderList() { + return headers_; + } + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public int getHeadersCount() { + return headers_.size(); + } + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
      * 
* - * optional string signed_uri = 3; + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; */ - public com.google.protobuf.ByteString - getSignedUriBytes() { - java.lang.Object ref = signedUri_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - signedUri_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getHeaders(int index) { + return headers_.get(index); + } + /** + *
+     * A collection of HTTP headers that should be specified when uploading to
+     * or downloading from the specified `signed_uri`
+     * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder getHeadersOrBuilder( + int index) { + return headers_.get(index); } - public static final int TYPE_FIELD_NUMBER = 4; + public static final int TYPE_FIELD_NUMBER = 5; private int type_; /** *
@@ -517,7 +1483,7 @@ public java.lang.String getSignedUri() {
      * or an Azure Shared Access Signature URI)
      * 
* - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public boolean hasType() { return ((bitField0_ & 0x00000008) == 0x00000008); @@ -528,7 +1494,7 @@ public boolean hasType() { * or an Azure Shared Access Signature URI) * * - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType() { @SuppressWarnings("deprecation") @@ -559,8 +1525,11 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (((bitField0_ & 0x00000004) == 0x00000004)) { com.google.protobuf.GeneratedMessageV3.writeString(output, 3, signedUri_); } + for (int i = 0; i < headers_.size(); i++) { + output.writeMessage(4, headers_.get(i)); + } if (((bitField0_ & 0x00000008) == 0x00000008)) { - output.writeEnum(4, type_); + output.writeEnum(5, type_); } unknownFields.writeTo(output); } @@ -580,9 +1549,13 @@ public int getSerializedSize() { if (((bitField0_ & 0x00000004) == 0x00000004)) { size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, signedUri_); } + for (int i = 0; i < headers_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, headers_.get(i)); + } if (((bitField0_ & 0x00000008) == 0x00000008)) { size += com.google.protobuf.CodedOutputStream - .computeEnumSize(4, type_); + .computeEnumSize(5, type_); } size += unknownFields.getSerializedSize(); memoizedSize = size; @@ -615,6 +1588,8 @@ public boolean equals(final java.lang.Object obj) { result = result && getSignedUri() .equals(other.getSignedUri()); } + result = result && getHeadersList() + .equals(other.getHeadersList()); result = result && (hasType() == other.hasType()); if (hasType()) { result = result && type_ == other.type_; @@ -642,6 +1617,10 @@ public int hashCode() { hash = (37 * hash) + SIGNED_URI_FIELD_NUMBER; hash = (53 * hash) + getSignedUri().hashCode(); } + if 
(getHeadersCount() > 0) { + hash = (37 * hash) + HEADERS_FIELD_NUMBER; + hash = (53 * hash) + getHeadersList().hashCode(); + } if (hasType()) { hash = (37 * hash) + TYPE_FIELD_NUMBER; hash = (53 * hash) + type_; @@ -774,6 +1753,7 @@ private Builder( private void maybeForceBuilderInitialization() { if (com.google.protobuf.GeneratedMessageV3 .alwaysUseFieldBuilders) { + getHeadersFieldBuilder(); } } @java.lang.Override @@ -785,8 +1765,14 @@ public Builder clear() { bitField0_ = (bitField0_ & ~0x00000002); signedUri_ = ""; bitField0_ = (bitField0_ & ~0x00000004); + if (headersBuilder_ == null) { + headers_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000008); + } else { + headersBuilder_.clear(); + } type_ = 1; - bitField0_ = (bitField0_ & ~0x00000008); + bitField0_ = (bitField0_ & ~0x00000010); return this; } @@ -827,7 +1813,16 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInf to_bitField0_ |= 0x00000004; } result.signedUri_ = signedUri_; - if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + if (headersBuilder_ == null) { + if (((bitField0_ & 0x00000008) == 0x00000008)) { + headers_ = java.util.Collections.unmodifiableList(headers_); + bitField0_ = (bitField0_ & ~0x00000008); + } + result.headers_ = headers_; + } else { + result.headers_ = headersBuilder_.build(); + } + if (((from_bitField0_ & 0x00000010) == 0x00000010)) { to_bitField0_ |= 0x00000008; } result.type_ = type_; @@ -895,6 +1890,32 @@ public Builder mergeFrom(com.databricks.api.proto.mlflow.DatabricksArtifacts.Art signedUri_ = other.signedUri_; onChanged(); } + if (headersBuilder_ == null) { + if (!other.headers_.isEmpty()) { + if (headers_.isEmpty()) { + headers_ = other.headers_; + bitField0_ = (bitField0_ & ~0x00000008); + } else { + ensureHeadersIsMutable(); + headers_.addAll(other.headers_); + } + onChanged(); + } + } else { + if (!other.headers_.isEmpty()) { + if (headersBuilder_.isEmpty()) { + headersBuilder_.dispose(); + 
headersBuilder_ = null; + headers_ = other.headers_; + bitField0_ = (bitField0_ & ~0x00000008); + headersBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? + getHeadersFieldBuilder() : null; + } else { + headersBuilder_.addAllMessages(other.headers_); + } + } + } if (other.hasType()) { setType(other.getType()); } @@ -1240,6 +2261,336 @@ public Builder setSignedUriBytes( return this; } + private java.util.List headers_ = + java.util.Collections.emptyList(); + private void ensureHeadersIsMutable() { + if (!((bitField0_ & 0x00000008) == 0x00000008)) { + headers_ = new java.util.ArrayList(headers_); + bitField0_ |= 0x00000008; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder> headersBuilder_; + + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public java.util.List getHeadersList() { + if (headersBuilder_ == null) { + return java.util.Collections.unmodifiableList(headers_); + } else { + return headersBuilder_.getMessageList(); + } + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public int getHeadersCount() { + if (headersBuilder_ == null) { + return headers_.size(); + } else { + return headersBuilder_.getCount(); + } + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader getHeaders(int index) { + if (headersBuilder_ == null) { + return headers_.get(index); + } else { + return headersBuilder_.getMessage(index); + } + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder setHeaders( + int index, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader value) { + if (headersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureHeadersIsMutable(); + headers_.set(index, value); + onChanged(); + } else { + headersBuilder_.setMessage(index, value); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder setHeaders( + int index, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder builderForValue) { + if (headersBuilder_ == null) { + ensureHeadersIsMutable(); + headers_.set(index, builderForValue.build()); + onChanged(); + } else { + headersBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder addHeaders(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader value) { + if (headersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureHeadersIsMutable(); + headers_.add(value); + onChanged(); + } else { + headersBuilder_.addMessage(value); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder addHeaders( + int index, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader value) { + if (headersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureHeadersIsMutable(); + headers_.add(index, value); + onChanged(); + } else { + headersBuilder_.addMessage(index, value); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder addHeaders( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder builderForValue) { + if (headersBuilder_ == null) { + ensureHeadersIsMutable(); + headers_.add(builderForValue.build()); + onChanged(); + } else { + headersBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder addHeaders( + int index, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder builderForValue) { + if (headersBuilder_ == null) { + ensureHeadersIsMutable(); + headers_.add(index, builderForValue.build()); + onChanged(); + } else { + headersBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder addAllHeaders( + java.lang.Iterable values) { + if (headersBuilder_ == null) { + ensureHeadersIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, headers_); + onChanged(); + } else { + headersBuilder_.addAllMessages(values); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder clearHeaders() { + if (headersBuilder_ == null) { + headers_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000008); + onChanged(); + } else { + headersBuilder_.clear(); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public Builder removeHeaders(int index) { + if (headersBuilder_ == null) { + ensureHeadersIsMutable(); + headers_.remove(index); + onChanged(); + } else { + headersBuilder_.remove(index); + } + return this; + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder getHeadersBuilder( + int index) { + return getHeadersFieldBuilder().getBuilder(index); + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder getHeadersOrBuilder( + int index) { + if (headersBuilder_ == null) { + return headers_.get(index); } else { + return headersBuilder_.getMessageOrBuilder(index); + } + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public java.util.List + getHeadersOrBuilderList() { + if (headersBuilder_ != null) { + return headersBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(headers_); + } + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder addHeadersBuilder() { + return getHeadersFieldBuilder().addBuilder( + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.getDefaultInstance()); + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder addHeadersBuilder( + int index) { + return getHeadersFieldBuilder().addBuilder( + index, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.getDefaultInstance()); + } + /** + *
+       * A collection of HTTP headers that should be specified when uploading to
+       * or downloading from the specified `signed_uri`
+       * 
+ * + * repeated .mlflow.ArtifactCredentialInfo.HttpHeader headers = 4; + */ + public java.util.List + getHeadersBuilderList() { + return getHeadersFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder> + getHeadersFieldBuilder() { + if (headersBuilder_ == null) { + headersBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeader.Builder, com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialInfo.HttpHeaderOrBuilder>( + headers_, + ((bitField0_ & 0x00000008) == 0x00000008), + getParentForChildren(), + isClean()); + headers_ = null; + } + return headersBuilder_; + } + private int type_ = 1; /** *
@@ -1247,10 +2598,10 @@ public Builder setSignedUriBytes(
        * or an Azure Shared Access Signature URI)
        * 
* - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public boolean hasType() { - return ((bitField0_ & 0x00000008) == 0x00000008); + return ((bitField0_ & 0x00000010) == 0x00000010); } /** *
@@ -1258,7 +2609,7 @@ public boolean hasType() {
        * or an Azure Shared Access Signature URI)
        * 
* - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType getType() { @SuppressWarnings("deprecation") @@ -1271,13 +2622,13 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialTyp * or an Azure Shared Access Signature URI) * * - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public Builder setType(com.databricks.api.proto.mlflow.DatabricksArtifacts.ArtifactCredentialType value) { if (value == null) { throw new NullPointerException(); } - bitField0_ |= 0x00000008; + bitField0_ |= 0x00000010; type_ = value.getNumber(); onChanged(); return this; @@ -1288,10 +2639,10 @@ public Builder setType(com.databricks.api.proto.mlflow.DatabricksArtifacts.Artif * or an Azure Shared Access Signature URI) * * - * optional .mlflow.ArtifactCredentialType type = 4; + * optional .mlflow.ArtifactCredentialType type = 5; */ public Builder clearType() { - bitField0_ = (bitField0_ & ~0x00000008); + bitField0_ = (bitField0_ & ~0x00000010); type_ = 1; onChanged(); return this; @@ -1358,7 +2709,7 @@ public interface GetCredentialsForReadOrBuilder extends * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ boolean hasRunId(); /** @@ -1366,7 +2717,7 @@ public interface GetCredentialsForReadOrBuilder extends * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ java.lang.String getRunId(); /** @@ -1374,7 +2725,7 @@ public interface GetCredentialsForReadOrBuilder extends * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 
[(.mlflow.validate_required) = true]; */ com.google.protobuf.ByteString getRunIdBytes(); @@ -1385,7 +2736,7 @@ public interface GetCredentialsForReadOrBuilder extends * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ boolean hasPath(); /** @@ -1394,7 +2745,7 @@ public interface GetCredentialsForReadOrBuilder extends * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ java.lang.String getPath(); /** @@ -1403,7 +2754,7 @@ public interface GetCredentialsForReadOrBuilder extends * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ com.google.protobuf.ByteString getPathBytes(); @@ -2178,7 +3529,7 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForRead * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public boolean hasRunId() { return ((bitField0_ & 0x00000001) == 0x00000001); @@ -2188,7 +3539,7 @@ public boolean hasRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public java.lang.String getRunId() { java.lang.Object ref = runId_; @@ -2209,7 +3560,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getRunIdBytes() { @@ -2233,7 +3584,7 @@ public java.lang.String getRunId() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public 
boolean hasPath() { return ((bitField0_ & 0x00000002) == 0x00000002); @@ -2244,7 +3595,7 @@ public boolean hasPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public java.lang.String getPath() { java.lang.Object ref = path_; @@ -2266,7 +3617,7 @@ public java.lang.String getPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getPathBytes() { @@ -2630,7 +3981,7 @@ public Builder mergeFrom( * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public boolean hasRunId() { return ((bitField0_ & 0x00000001) == 0x00000001); @@ -2640,7 +3991,7 @@ public boolean hasRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public java.lang.String getRunId() { java.lang.Object ref = runId_; @@ -2661,7 +4012,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getRunIdBytes() { @@ -2681,7 +4032,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public Builder setRunId( java.lang.String value) { @@ -2698,7 +4049,7 @@ public Builder setRunId( * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public 
Builder clearRunId() { bitField0_ = (bitField0_ & ~0x00000001); @@ -2711,7 +4062,7 @@ public Builder clearRunId() { * The ID of the MLflow Run for which to fetch artifact read credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public Builder setRunIdBytes( com.google.protobuf.ByteString value) { @@ -2731,7 +4082,7 @@ public Builder setRunIdBytes( * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public boolean hasPath() { return ((bitField0_ & 0x00000002) == 0x00000002); @@ -2742,7 +4093,7 @@ public boolean hasPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public java.lang.String getPath() { java.lang.Object ref = path_; @@ -2764,7 +4115,7 @@ public java.lang.String getPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getPathBytes() { @@ -2785,7 +4136,7 @@ public java.lang.String getPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public Builder setPath( java.lang.String value) { @@ -2803,7 +4154,7 @@ public Builder setPath( * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public Builder clearPath() { bitField0_ = (bitField0_ & ~0x00000002); @@ -2817,7 +4168,7 @@ public Builder clearPath() { * fetch artifact read credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public Builder setPathBytes( com.google.protobuf.ByteString value) { @@ -2891,7 +4242,7 @@ public interface GetCredentialsForWriteOrBuilder extends * The ID of the MLflow 
Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ boolean hasRunId(); /** @@ -2899,7 +4250,7 @@ public interface GetCredentialsForWriteOrBuilder extends * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ java.lang.String getRunId(); /** @@ -2907,7 +4258,7 @@ public interface GetCredentialsForWriteOrBuilder extends * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ com.google.protobuf.ByteString getRunIdBytes(); @@ -2918,7 +4269,7 @@ public interface GetCredentialsForWriteOrBuilder extends * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ boolean hasPath(); /** @@ -2927,7 +4278,7 @@ public interface GetCredentialsForWriteOrBuilder extends * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ java.lang.String getPath(); /** @@ -2936,7 +4287,7 @@ public interface GetCredentialsForWriteOrBuilder extends * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ com.google.protobuf.ByteString getPathBytes(); @@ -3711,7 +5062,7 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrit * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public boolean hasRunId() { return ((bitField0_ & 0x00000001) == 0x00000001); @@ -3721,7 +5072,7 @@ public boolean hasRunId() { * The ID of the MLflow Run for which to fetch 
artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public java.lang.String getRunId() { java.lang.Object ref = runId_; @@ -3742,7 +5093,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getRunIdBytes() { @@ -3766,7 +5117,7 @@ public java.lang.String getRunId() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public boolean hasPath() { return ((bitField0_ & 0x00000002) == 0x00000002); @@ -3777,7 +5128,7 @@ public boolean hasPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public java.lang.String getPath() { java.lang.Object ref = path_; @@ -3799,7 +5150,7 @@ public java.lang.String getPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getPathBytes() { @@ -4163,7 +5514,7 @@ public Builder mergeFrom( * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public boolean hasRunId() { return ((bitField0_ & 0x00000001) == 0x00000001); @@ -4173,7 +5524,7 @@ public boolean hasRunId() { * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public java.lang.String getRunId() { java.lang.Object ref = runId_; @@ -4194,7 +5545,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact 
write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getRunIdBytes() { @@ -4214,7 +5565,7 @@ public java.lang.String getRunId() { * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public Builder setRunId( java.lang.String value) { @@ -4231,7 +5582,7 @@ public Builder setRunId( * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public Builder clearRunId() { bitField0_ = (bitField0_ & ~0x00000001); @@ -4244,7 +5595,7 @@ public Builder clearRunId() { * The ID of the MLflow Run for which to fetch artifact write credentials * * - * optional string run_id = 1; + * optional string run_id = 1 [(.mlflow.validate_required) = true]; */ public Builder setRunIdBytes( com.google.protobuf.ByteString value) { @@ -4264,7 +5615,7 @@ public Builder setRunIdBytes( * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public boolean hasPath() { return ((bitField0_ & 0x00000002) == 0x00000002); @@ -4275,7 +5626,7 @@ public boolean hasPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public java.lang.String getPath() { java.lang.Object ref = path_; @@ -4297,7 +5648,7 @@ public java.lang.String getPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public com.google.protobuf.ByteString getPathBytes() { @@ -4318,7 +5669,7 @@ public java.lang.String getPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string 
path = 2 [(.mlflow.validate_required) = true]; */ public Builder setPath( java.lang.String value) { @@ -4336,7 +5687,7 @@ public Builder setPath( * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public Builder clearPath() { bitField0_ = (bitField0_ & ~0x00000002); @@ -4350,7 +5701,7 @@ public Builder clearPath() { * fetch artifact write credentials * * - * optional string path = 2; + * optional string path = 2 [(.mlflow.validate_required) = true]; */ public Builder setPathBytes( com.google.protobuf.ByteString value) { @@ -4420,6 +5771,11 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrit private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_mlflow_GetCredentialsForRead_descriptor; private static final @@ -4450,31 +5806,36 @@ public com.databricks.api.proto.mlflow.DatabricksArtifacts.GetCredentialsForWrit static { java.lang.String[] descriptorData = { "\n\032databricks_artifacts.proto\022\006mlflow\032\025sc" + - "alapb/scalapb.proto\032\020databricks.proto\"x\n" + - "\026ArtifactCredentialInfo\022\016\n\006run_id\030\001 \001(\t\022" + - "\014\n\004path\030\002 \001(\t\022\022\n\nsigned_uri\030\003 \001(\t\022,\n\004typ" + - "e\030\004 \001(\0162\036.mlflow.ArtifactCredentialType\"" + - "\243\001\n\025GetCredentialsForRead\022\016\n\006run_id\030\001 \001(" + - "\t\022\014\n\004path\030\002 \001(\t\032?\n\010Response\0223\n\013credentia" + - "ls\030\001 
\001(\0132\036.mlflow.ArtifactCredentialInfo" + - ":+\342?(\n&com.databricks.rpc.RPC[$this.Resp" + - "onse]\"\244\001\n\026GetCredentialsForWrite\022\016\n\006run_" + - "id\030\001 \001(\t\022\014\n\004path\030\002 \001(\t\032?\n\010Response\0223\n\013cr" + - "edentials\030\001 \001(\0132\036.mlflow.ArtifactCredent" + - "ialInfo:+\342?(\n&com.databricks.rpc.RPC[$th" + - "is.Response]*B\n\026ArtifactCredentialType\022\021" + - "\n\rAZURE_SAS_URI\020\001\022\025\n\021AWS_PRESIGNED_URL\020\002" + - "2\342\002\n DatabricksMlflowArtifactsService\022\233\001" + - "\n\025getCredentialsForRead\022\035.mlflow.GetCred" + - "entialsForRead\032&.mlflow.GetCredentialsFo" + - "rRead.Response\";\362\206\0317\n3\n\003GET\022&/mlflow/art" + - "ifacts/credentials-for-read\032\004\010\002\020\000\020\003\022\237\001\n\026" + - "getCredentialsForWrite\022\036.mlflow.GetCrede" + - "ntialsForWrite\032\'.mlflow.GetCredentialsFo" + - "rWrite.Response\"<\362\206\0318\n4\n\003GET\022\'/mlflow/ar" + - "tifacts/credentials-for-write\032\004\010\002\020\000\020\003B,\n" + - "\037com.databricks.api.proto.mlflow\220\001\001\240\001\001\342?" + - "\002\020\001" + "alapb/scalapb.proto\032\020databricks.proto\"\337\001" + + "\n\026ArtifactCredentialInfo\022\016\n\006run_id\030\001 \001(\t" + + "\022\014\n\004path\030\002 \001(\t\022\022\n\nsigned_uri\030\003 \001(\t\022:\n\007he" + + "aders\030\004 \003(\0132).mlflow.ArtifactCredentialI" + + "nfo.HttpHeader\022,\n\004type\030\005 \001(\0162\036.mlflow.Ar" + + "tifactCredentialType\032)\n\nHttpHeader\022\014\n\004na" + + "me\030\001 \001(\t\022\r\n\005value\030\002 \001(\t\"\343\001\n\025GetCredentia" + + "lsForRead\022\024\n\006run_id\030\001 \001(\tB\004\370\206\031\001\022\022\n\004path\030" + + "\002 \001(\tB\004\370\206\031\001\032?\n\010Response\0223\n\013credentials\030\001" + + " \001(\0132\036.mlflow.ArtifactCredentialInfo:_\342?" 
+ + "(\n&com.databricks.rpc.RPC[$this.Response" + + "]\342?1\n/com.databricks.mlflow.api.MlflowTr" + + "ackingMessage\"\344\001\n\026GetCredentialsForWrite" + + "\022\024\n\006run_id\030\001 \001(\tB\004\370\206\031\001\022\022\n\004path\030\002 \001(\tB\004\370\206" + + "\031\001\032?\n\010Response\0223\n\013credentials\030\001 \001(\0132\036.ml" + + "flow.ArtifactCredentialInfo:_\342?(\n&com.da" + + "tabricks.rpc.RPC[$this.Response]\342?1\n/com" + + ".databricks.mlflow.api.MlflowTrackingMes" + + "sage*B\n\026ArtifactCredentialType\022\021\n\rAZURE_" + + "SAS_URI\020\001\022\025\n\021AWS_PRESIGNED_URL\020\0022\342\002\n Dat" + + "abricksMlflowArtifactsService\022\233\001\n\025getCre" + + "dentialsForRead\022\035.mlflow.GetCredentialsF" + + "orRead\032&.mlflow.GetCredentialsForRead.Re" + + "sponse\";\362\206\0317\n3\n\003GET\022&/mlflow/artifacts/c" + + "redentials-for-read\032\004\010\002\020\000\020\003\022\237\001\n\026getCrede" + + "ntialsForWrite\022\036.mlflow.GetCredentialsFo" + + "rWrite\032\'.mlflow.GetCredentialsForWrite.R" + + "esponse\"<\362\206\0318\n4\n\003GET\022\'/mlflow/artifacts/" + + "credentials-for-write\032\004\010\002\020\000\020\003B,\n\037com.dat" + + "abricks.api.proto.mlflow\220\001\001\240\001\001\342?\002\020\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. 
InternalDescriptorAssigner() { @@ -4495,7 +5856,13 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_mlflow_ArtifactCredentialInfo_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_mlflow_ArtifactCredentialInfo_descriptor, - new java.lang.String[] { "RunId", "Path", "SignedUri", "Type", }); + new java.lang.String[] { "RunId", "Path", "SignedUri", "Headers", "Type", }); + internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor = + internal_static_mlflow_ArtifactCredentialInfo_descriptor.getNestedTypes().get(0); + internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_mlflow_ArtifactCredentialInfo_HttpHeader_descriptor, + new java.lang.String[] { "Name", "Value", }); internal_static_mlflow_GetCredentialsForRead_descriptor = getDescriptor().getMessageTypes().get(1); internal_static_mlflow_GetCredentialsForRead_fieldAccessorTable = new @@ -4523,6 +5890,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( com.google.protobuf.ExtensionRegistry registry = com.google.protobuf.ExtensionRegistry.newInstance(); registry.add(com.databricks.api.proto.databricks.Databricks.rpc); + registry.add(com.databricks.api.proto.databricks.Databricks.validateRequired); registry.add(org.mlflow.scalapb_interface.Scalapb.message); registry.add(org.mlflow.scalapb_interface.Scalapb.options); com.google.protobuf.Descriptors.FileDescriptor diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 05649a47e4841..934261ed57197 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -1,25 +1,25 @@ -from azure.storage.blob import BlobClient -from azure.core.exceptions import ClientAuthenticationError - -import os -import uuid import base64 import logging 
-import requests +import os import posixpath +import requests +import uuid + +from azure.core.exceptions import ClientAuthenticationError +from azure.storage.blob import BlobClient from mlflow.entities import FileInfo from mlflow.exceptions import MlflowException -from mlflow.store.artifact.artifact_repo import ArtifactRepository +from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, INTERNAL_ERROR from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \ GetCredentialsForWrite, GetCredentialsForRead, ArtifactCredentialType -from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE from mlflow.protos.service_pb2 import MlflowService, ListArtifacts -from mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri -from mlflow.utils.proto_json_utils import message_to_json +from mlflow.store.artifact.artifact_repo import ArtifactRepository +from mlflow.utils.databricks_utils import get_databricks_host_creds from mlflow.utils.file_utils import relative_path_to_artifact_path, yield_file_in_chunks +from mlflow.utils.proto_json_utils import message_to_json from mlflow.utils.rest_utils import call_endpoint, extract_api_info_for_service -from mlflow.utils.databricks_utils import get_databricks_host_creds +from mlflow.utils.uri import extract_and_normalize_path, is_databricks_acled_artifacts_uri _logger = logging.getLogger(__name__) _PATH_PREFIX = "/api/2.0" @@ -86,11 +86,10 @@ def _get_read_credentials(self, run_id, path=None): return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body) - def _extract_headers_from_credentials(self, credential): - headers = dict() - for header in credential.headers: - headers[header.name] = header.value - return headers + def _extract_headers_from_credentials(self, headers): + return { + header.name: header.value for header in headers + } def _azure_upload_file(self, credentials, local_file, artifact_path): """ @@ -106,7 +105,7 @@ def 
_azure_upload_file(self, credentials, local_file, artifact_path): stage_block and the commit, a second try-except block refreshes credentials if needed. """ try: - headers = self._extract_headers_from_credentials(credentials) + headers = self._extract_headers_from_credentials(credentials.headers) service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None, headers=headers) uploading_block_list = list() @@ -124,7 +123,7 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): service.stage_block(block_id, chunk, headers=headers) uploading_block_list.append(block_id) try: - service.commit_block_list(uploading_block_list) + service.commit_block_list(uploading_block_list, headers=headers) except ClientAuthenticationError: _logger.warning( "Failed to authorize request, possibly due to credential expiration." @@ -132,16 +131,16 @@ def _azure_upload_file(self, credentials, local_file, artifact_path): credentials = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri service = BlobClient.from_blob_url(blob_url=credentials, credential=None) - service.commit_block_list(uploading_block_list) + service.commit_block_list(uploading_block_list, headers=headers) except Exception as err: raise MlflowException(err) def _aws_upload_file(self, credentials, local_file): try: - headers = self._extract_headers_from_credentials(credentials) + headers = self._extract_headers_from_credentials(credentials.headers) signed_write_uri = credentials.signed_uri with open(local_file, 'rb') as file: - put_request = requests.put(signed_write_uri, headers=headers, data=file) + put_request = requests.put(signed_write_uri, file, headers=headers) put_request.raise_for_status() except Exception as err: raise MlflowException(err) @@ -152,7 +151,8 @@ def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path): elif cloud_credentials.credentials.type == ArtifactCredentialType.AWS_PRESIGNED_URL: 
self._aws_upload_file(cloud_credentials.credentials, local_file) else: - raise MlflowException('Not implemented yet') + raise MlflowException(message='Cloud provider not supported.', + error_code=INTERNAL_ERROR) def _download_from_cloud(self, cloud_credential, local_file_path): """ @@ -170,7 +170,7 @@ def _download_from_cloud(self, cloud_credential, local_file_path): if cloud_credential.type not in [ArtifactCredentialType.AZURE_SAS_URI, ArtifactCredentialType.AWS_PRESIGNED_URL]: raise MlflowException(message='Cloud provider not supported.', - error_code=INVALID_PARAMETER_VALUE) + error_code=INTERNAL_ERROR) try: signed_read_uri = cloud_credential.signed_uri with requests.get(signed_read_uri, stream=True) as response: diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index af7587501ca1f..0045447631415 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -1,18 +1,20 @@ # -*- coding: utf-8 -*- import os -import pytest +from azure.storage.blob import BlobClient import mock +import pytest +from requests.models import Response from unittest.mock import ANY -from azure.storage.blob import BlobClient +from mlflow.entities.file_info import FileInfo as FileInfoEntity from mlflow.exceptions import MlflowException -from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository -from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository -from mlflow.protos.service_pb2 import ListArtifacts, FileInfo from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \ ArtifactCredentialType, ArtifactCredentialInfo -from mlflow.entities.file_info import FileInfo as FileInfoEntity +from mlflow.protos.service_pb2 import ListArtifacts, FileInfo +from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository +from 
mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository + @pytest.fixture() @@ -46,9 +48,11 @@ def test_dir(tmpdir): return tmpdir -MOCK_AZURE_SIGNED_URI = "this_is_a_mock_sas_for_azure" -MOCK_AWS_SIGNED_URI = "this_is_a_mock_presigned_uri_for_aws" -MOCK_RUN_ID = 'MOCK-RUN-ID' +MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure" +MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?" +MOCK_RUN_ID = "MOCK-RUN-ID" +MOCK_HEADERS = [ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'), + ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')] class TestDatabricksArtifactRepository(object): @@ -86,7 +90,7 @@ def test_log_artifact_azure(self, databricks_artifact_repo, test_file, artifact_ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') \ as azure_upload_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, - type=ArtifactCredentialType.AZURE_SAS_URI) + type=ArtifactCredentialType.AZURE_SAS_URI, ) write_credentials_response_proto = GetCredentialsForWrite.Response( credentials=mock_credentials) write_credentials_mock.return_value = write_credentials_response_proto @@ -96,6 +100,53 @@ def test_log_artifact_azure(self, databricks_artifact_repo, test_file, artifact_ azure_upload_mock.assert_called_with(mock_credentials, test_file.strpath, expected_location) + @pytest.mark.parametrize("artifact_path,expected_location", [ + (None, 'test.txt'), + ]) + def test_log_artifact_azure_with_headers(self, databricks_artifact_repo, test_file, + artifact_path, expected_location): + expected_headers = { + header.name: header.value for header in MOCK_HEADERS + } + mock_blob_service = mock.MagicMock(autospec=BlobClient) + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch( + 'azure.storage.blob.BlobClient.from_blob_url') as mock_create_blob_client: + mock_credentials = 
ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI, + headers=MOCK_HEADERS) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + + mock_create_blob_client.return_value = mock_blob_service + mock_blob_service.stage_block.side_effect = None + mock_blob_service.commit_block_list.side_effect = None + + databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) + mock_create_blob_client.assert_called_with(blob_url=MOCK_AZURE_SIGNED_URI, + credential=None, + headers=expected_headers) + mock_blob_service.stage_block.assert_called_with(ANY, ANY, headers=expected_headers) + mock_blob_service.commit_block_list.assert_called_with(ANY, headers=expected_headers) + + def test_log_artifact_azure_blob_client_sas_error(self, databricks_artifact_repo, test_file): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch( + 'azure.storage.blob.BlobClient.from_blob_url') as mock_create_blob_client: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + mock_create_blob_client.side_effect = MlflowException("MOCK ERROR") + with pytest.raises(MlflowException): + databricks_artifact_repo.log_artifact(test_file.strpath) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) + @pytest.mark.parametrize("artifact_path,expected_location", [ (None, 'test.txt'), ]) @@ -115,19 +166,44 @@ def test_log_artifact_aws(self, databricks_artifact_repo, test_file, artifact_pa write_credentials_mock.assert_called_with(MOCK_RUN_ID, 
expected_location) aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath) - def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ): - mock_blob_service = mock.MagicMock(autospec=BlobClient) + @pytest.mark.parametrize("artifact_path,expected_location", [ + (None, 'test.txt'), + ]) + def test_log_artifact_aws_with_headers(self, databricks_artifact_repo, test_file, artifact_path, + expected_location): + expected_headers = { + header.name: header.value for header in MOCK_HEADERS + } + mock_response = Response() + mock_response.status_code = 200 with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ - as write_credentials_mock: - mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, - type=ArtifactCredentialType.AZURE_SAS_URI) + as write_credentials_mock, \ + mock.patch('requests.put') as request_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AWS_SIGNED_URI, + type=ArtifactCredentialType.AWS_PRESIGNED_URL, + headers=MOCK_HEADERS) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + request_mock.return_value = mock_response + databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) + request_mock.assert_called_with(ANY, ANY, + headers=expected_headers) + + def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, test_file): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch('requests.put') as request_mock: + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AWS_SIGNED_URI, + type=ArtifactCredentialType.AWS_PRESIGNED_URL) write_credentials_response_proto = GetCredentialsForWrite.Response( credentials=mock_credentials) write_credentials_mock.return_value = 
write_credentials_response_proto - mock_blob_service.from_blob_url().return_value = MlflowException("MOCK ERROR") + request_mock.side_effect = MlflowException("MOCK ERROR") with pytest.raises(MlflowException): databricks_artifact_repo.log_artifact(test_file.strpath) - write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) @pytest.mark.parametrize("artifact_path", [ None, @@ -174,12 +250,14 @@ def test_list_artifacts(self, databricks_artifact_repo): artifacts = databricks_artifact_repo.list_artifacts('a.txt') assert len(artifacts) == 0 - @pytest.mark.parametrize("remote_file_path, local_path", [ - ('test_file.txt', ''), - ('test_file.txt', None), - ('output/test_file', None), + @pytest.mark.parametrize("remote_file_path, local_path, cloud_credential_type", [ + ('test_file.txt', '', ArtifactCredentialType.AZURE_SAS_URI), + ('test_file.txt', None, ArtifactCredentialType.AZURE_SAS_URI), + ('output/test_file', None, ArtifactCredentialType.AZURE_SAS_URI), + ('test_file.txt', '', ArtifactCredentialType.AWS_PRESIGNED_URL), ]) - def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path, local_path): + def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path, local_path, + cloud_credential_type): with mock.patch( DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ as read_credentials_mock, \ @@ -187,7 +265,7 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ as download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, - type=ArtifactCredentialType.AZURE_SAS_URI) + type=cloud_credential_type) read_credentials_response_proto = GetCredentialsForRead.Response( credentials=mock_credentials) read_credentials_mock.return_value = read_credentials_response_proto @@ -212,4 +290,4 @@ def 
test_databricks_download_file_get_request_fail(self, databricks_artifact_rep request_mock.return_value = MlflowException("MOCK ERROR") with pytest.raises(MlflowException): databricks_artifact_repo.download_artifacts(test_file.strpath) - read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath) + read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath) From 3d923cb61a0259f1288997d0a9504cff6b92c289 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Sun, 7 Jun 2020 18:34:22 -0700 Subject: [PATCH 15/28] Fix - needs docs and tests --- .../artifact/databricks_artifact_repo.py | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 934261ed57197..afbcec2e1271b 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -13,7 +13,7 @@ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, INTERNAL_ERROR from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \ GetCredentialsForWrite, GetCredentialsForRead, ArtifactCredentialType -from mlflow.protos.service_pb2 import MlflowService, ListArtifacts +from mlflow.protos.service_pb2 import MlflowService, GetRun, ListArtifacts from mlflow.store.artifact.artifact_repo import ArtifactRepository from mlflow.utils.databricks_utils import get_databricks_host_creds from mlflow.utils.file_utils import relative_path_to_artifact_path, yield_file_in_chunks @@ -54,6 +54,16 @@ def __init__(self, artifact_uri): error_code=INVALID_PARAMETER_VALUE) self.run_id = self._extract_run_id(self.artifact_uri) + artifact_repo_root_path = extract_and_normalize_path(artifact_uri) + run_artifact_root_uri = self._get_run_artifact_root(self.run_id) + run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri) + if artifact_repo_root_path == run_artifact_root_path: + 
self.run_relative_artifact_repo_root_path = "" + else: + self.run_relative_artifact_repo_root_path = posixpath.relpath( + path=artifact_repo_root_path, start=run_artifact_root_path + ) + @staticmethod def _extract_run_id(artifact_uri): """ @@ -76,6 +86,12 @@ def _call_endpoint(self, service, api, json_body): return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto) + def _get_run_artifact_root(self, run_id): + json_body = message_to_json(GetRun(run_id=run_id)) + run_response = self._call_endpoint(MlflowService, + GetRun, json_body) + return run_response.run.info.artifact_uri + def _get_write_credentials(self, run_id, path=None): json_body = message_to_json(GetCredentialsForWrite(run_id=run_id, path=path)) return self._call_endpoint(DatabricksMlflowArtifactsService, @@ -187,8 +203,13 @@ def log_artifact(self, local_file, artifact_path=None): basename = os.path.basename(local_file) artifact_path = artifact_path or "" artifact_path = posixpath.join(artifact_path, basename) - write_credentials = self._get_write_credentials(self.run_id, artifact_path) - self._upload_to_cloud(write_credentials, local_file, artifact_path) + if len(artifact_path) > 0: + run_relative_artifact_path = posixpath.join( + self.run_relative_artifact_repo_root_path, artifact_path) + else: + run_relative_artifact_path = self.run_relative_artifact_repo_root_path + write_credentials = self._get_write_credentials(self.run_id, run_relative_artifact_path) + self._upload_to_cloud(write_credentials, local_file, run_relative_artifact_path) def log_artifacts(self, local_dir, artifact_path=None): artifact_path = artifact_path or "" @@ -203,7 +224,12 @@ def log_artifacts(self, local_dir, artifact_path=None): self.log_artifact(file_path, artifact_subdir) def list_artifacts(self, path=None): - json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path)) + if path: + run_relative_path = posixpath.join( + self.run_relative_artifact_repo_root_path, path) + else: + 
run_relative_path = self.run_relative_artifact_repo_root_path + json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path)) artifact_list = self._call_endpoint(MlflowService, ListArtifacts, json_body).files # If `path` is a file, ListArtifacts returns a single list element with the # same name as `path`. The list_artifacts API expects us to return an empty list in this @@ -212,13 +238,17 @@ def list_artifacts(self, path=None): and not artifact_list[0].is_dir: return [] infos = list() - for file in artifact_list: - artifact_size = None if file.is_dir else file.file_size - infos.append(FileInfo(file.path, file.is_dir, artifact_size)) + for output_file in artifact_list: + file_rel_path = posixpath.relpath( + path=output_file.path, start=self.run_relative_artifact_repo_root_path) + artifact_size = None if output_file.is_dir else output_file.file_size + infos.append(FileInfo(file_rel_path, output_file.is_dir, artifact_size)) return infos def _download_file(self, remote_file_path, local_path): - read_credentials = self._get_read_credentials(self.run_id, remote_file_path) + run_relative_remote_file_path = posixpath.join( + self.run_relative_artifact_repo_root_path, remote_file_path) + read_credentials = self._get_read_credentials(self.run_id, run_relative_remote_file_path) self._download_from_cloud(read_credentials.credentials, local_path) def delete_artifacts(self, artifact_path=None): From d3fab4a9af52dd503d86f7c89fe32156d325226e Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Sun, 7 Jun 2020 18:38:36 -0700 Subject: [PATCH 16/28] Comment and simplification --- mlflow/store/artifact/databricks_artifact_repo.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index afbcec2e1271b..491736c624d61 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ 
-54,15 +54,16 @@ def __init__(self, artifact_uri): error_code=INVALID_PARAMETER_VALUE) self.run_id = self._extract_run_id(self.artifact_uri) + # Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute + # the path of `artifact_uri` relative to the MLflow Run's artifact root + # (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact + # repository will be performed relative to this computed location artifact_repo_root_path = extract_and_normalize_path(artifact_uri) run_artifact_root_uri = self._get_run_artifact_root(self.run_id) run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri) - if artifact_repo_root_path == run_artifact_root_path: - self.run_relative_artifact_repo_root_path = "" - else: - self.run_relative_artifact_repo_root_path = posixpath.relpath( - path=artifact_repo_root_path, start=run_artifact_root_path - ) + self.run_relative_artifact_repo_root_path = posixpath.relpath( + path=artifact_repo_root_path, start=run_artifact_root_path + ) @staticmethod def _extract_run_id(artifact_uri): From 368973d59f269d7cdd96af70bfe29b7df787cbc6 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Sun, 7 Jun 2020 20:32:44 -0700 Subject: [PATCH 17/28] Special case for empty file upload to AWS --- mlflow/store/artifact/databricks_artifact_repo.py | 10 +++++++--- mlflow/store/artifact/dbfs_artifact_repo.py | 2 +- tests/store/artifact/test_databricks_artifact_repo.py | 1 - 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 934261ed57197..02b20d192895a 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -139,9 +139,13 @@ def _aws_upload_file(self, credentials, local_file): try: headers = self._extract_headers_from_credentials(credentials.headers) signed_write_uri = credentials.signed_uri - with open(local_file, 
'rb') as file: - put_request = requests.put(signed_write_uri, file, headers=headers) - put_request.raise_for_status() + # Putting an empty file in a request by reading file bytes gives 501 error. + if os.stat(local_file).st_size == 0: + put_request = requests.put(signed_write_uri, "", headers=headers) + else: + with open(local_file, 'rb') as file: + put_request = requests.put(signed_write_uri, file, headers=headers) + put_request.raise_for_status() except Exception as err: raise MlflowException(err) diff --git a/mlflow/store/artifact/dbfs_artifact_repo.py b/mlflow/store/artifact/dbfs_artifact_repo.py index 92c0a8b64c602..1f11effd16b5c 100644 --- a/mlflow/store/artifact/dbfs_artifact_repo.py +++ b/mlflow/store/artifact/dbfs_artifact_repo.py @@ -177,7 +177,7 @@ def dbfs_artifact_repo_factory(artifact_uri): raise MlflowException("DBFS URI must be of the form " "dbfs:/, but received {uri}".format(uri=artifact_uri)) if is_databricks_acled_artifacts_uri(artifact_uri): - return DatabricksArtifactRepository(artifact_uri) + return DatabricksArtifactRepository(cleaned_artifact_uri) elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \ and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false" \ and not artifact_uri.startswith("dbfs:/databricks/mlflow-registry"): diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 0045447631415..48ebba9b00b31 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -16,7 +16,6 @@ from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository - @pytest.fixture() def databricks_artifact_repo(): return get_artifact_repository('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifact') From 72427a303025b55480ca71cc9a70a9e01fca09ec Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 8 Jun 2020 12:58:53 -0700 Subject: [PATCH 18/28] Clean up and added tests for 
relative path --- .../artifact/databricks_artifact_repo.py | 7 +- .../artifact/test_databricks_artifact_repo.py | 112 +++++++++++------- 2 files changed, 77 insertions(+), 42 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 26f9dde14701d..405c2547ed769 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -40,7 +40,7 @@ class DatabricksArtifactRepository(ArtifactRepository): read and write files from/to this location. The artifact_uri is expected to be of the form - dbfs:/databricks/mlflow-tracking///artifacts/ + dbfs:/databricks/mlflow-tracking/// """ def __init__(self, artifact_uri): @@ -61,9 +61,12 @@ def __init__(self, artifact_uri): artifact_repo_root_path = extract_and_normalize_path(artifact_uri) run_artifact_root_uri = self._get_run_artifact_root(self.run_id) run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri) - self.run_relative_artifact_repo_root_path = posixpath.relpath( + run_relative_root_path = posixpath.relpath( path=artifact_repo_root_path, start=run_artifact_root_path ) + # If the paths are equal, then use empty string over "./" for ListArtifact compatibility. 
+ self.run_relative_artifact_repo_root_path = \ + "" if run_artifact_root_path == artifact_repo_root_path else run_relative_root_path @staticmethod def _extract_run_id(artifact_uri): diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 48ebba9b00b31..2e1af3504cbb9 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -15,16 +15,27 @@ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository - -@pytest.fixture() -def databricks_artifact_repo(): - return get_artifact_repository('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifact') - - DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo' DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + \ ".DatabricksArtifactRepository" +MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure" +MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?" +MOCK_RUN_ID = "MOCK-RUN-ID" +MOCK_HEADERS = [ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'), + ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')] +MOCK_RUN_ROOT_URI = \ + "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" + + +@pytest.fixture() +def databricks_artifact_repo(): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + return get_artifact_repository( + "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts") + @pytest.fixture() def test_file(tmpdir): @@ -47,37 +58,57 @@ def test_dir(tmpdir): return tmpdir -MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure" -MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?" 
-MOCK_RUN_ID = "MOCK-RUN-ID" -MOCK_HEADERS = [ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'), - ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')] - - class TestDatabricksArtifactRepository(object): def test_init_validation_and_cleaning(self): - repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact') - assert repo.artifact_uri == 'dbfs:/databricks/mlflow-tracking/EXP/RUN/artifact' - assert repo.run_id == 'RUN' - with pytest.raises(MlflowException): - DatabricksArtifactRepository('s3://test') - with pytest.raises(MlflowException): - DatabricksArtifactRepository('dbfs:/databricks/mlflow/EXP/RUN/artifact') + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + # Basic artifact uri + repo = get_artifact_repository( + 'dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts') + assert repo.artifact_uri == 'dbfs:/databricks/mlflow-tracking/' \ + 'MOCK-EXP/MOCK-RUN-ID/artifacts' + assert repo.run_id == MOCK_RUN_ID + assert repo.run_relative_artifact_repo_root_path == "" + + with pytest.raises(MlflowException): + DatabricksArtifactRepository('s3://test') + with pytest.raises(MlflowException): + DatabricksArtifactRepository('dbfs:/databricks/mlflow/EXP/RUN/artifact') + + @pytest.mark.parametrize("artifact_uri, expected_relative_path", [ + ('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts', ''), + ('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts/arty', 'arty'), + ('dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/awesome/path', '../awesome/path'), + ]) + def test_run_relative_artifact_repo_root_path(self, artifact_uri, expected_relative_path): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + # Basic 
artifact uri + repo = get_artifact_repository(artifact_uri) + assert repo.artifact_uri == artifact_uri + assert repo.run_id == MOCK_RUN_ID + assert repo.run_relative_artifact_repo_root_path == expected_relative_path def test_extract_run_id(self): - expected_run_id = "RUN_ID" - repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN_ID/artifact') - assert repo.run_id == expected_run_id - repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') - assert repo.run_id == expected_run_id - repo = get_artifact_repository( - 'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') - assert repo.run_id == expected_run_id - repo = get_artifact_repository( - 'dbfs:/databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') - assert repo.run_id == expected_run_id + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + expected_run_id = "RUN_ID" + repo = get_artifact_repository('dbfs:/databricks/mlflow-tracking/EXP/RUN_ID/artifact') + assert repo.run_id == expected_run_id + repo = get_artifact_repository( + 'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts') + assert repo.run_id == expected_run_id + repo = get_artifact_repository( + 'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/') + assert repo.run_id == expected_run_id + repo = get_artifact_repository( + 'dbfs:/databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//') + assert repo.run_id == expected_run_id - @pytest.mark.parametrize("artifact_path,expected_location", [ + @pytest.mark.parametrize("artifact_path, expected_location", [ (None, 'test.txt'), ('output', 'output/test.txt'), ('', 'test.txt'), @@ -89,7 +120,7 @@ def test_log_artifact_azure(self, databricks_artifact_repo, test_file, artifact_ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._azure_upload_file') \ as azure_upload_mock: mock_credentials = 
ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, - type=ArtifactCredentialType.AZURE_SAS_URI, ) + type=ArtifactCredentialType.AZURE_SAS_URI) write_credentials_response_proto = GetCredentialsForWrite.Response( credentials=mock_credentials) write_credentials_mock.return_value = write_credentials_response_proto @@ -249,12 +280,13 @@ def test_list_artifacts(self, databricks_artifact_repo): artifacts = databricks_artifact_repo.list_artifacts('a.txt') assert len(artifacts) == 0 - @pytest.mark.parametrize("remote_file_path, local_path, cloud_credential_type", [ - ('test_file.txt', '', ArtifactCredentialType.AZURE_SAS_URI), - ('test_file.txt', None, ArtifactCredentialType.AZURE_SAS_URI), - ('output/test_file', None, ArtifactCredentialType.AZURE_SAS_URI), - ('test_file.txt', '', ArtifactCredentialType.AWS_PRESIGNED_URL), - ]) + @pytest.mark.parametrize( + "remote_file_path, local_path, cloud_credential_type", [ + ('test_file.txt', '', ArtifactCredentialType.AZURE_SAS_URI), + ('test_file.txt', None, ArtifactCredentialType.AZURE_SAS_URI), + ('output/test_file', None, ArtifactCredentialType.AZURE_SAS_URI), + ('test_file.txt', '', ArtifactCredentialType.AWS_PRESIGNED_URL), + ]) def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path, local_path, cloud_credential_type): with mock.patch( @@ -262,7 +294,7 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ - as download_mock: + as download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type) read_credentials_response_proto = GetCredentialsForRead.Response( From f82c60738354ffc48ffd7ba602765dcaf8ce8b55 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Tue, 9 Jun 2020 13:33:51 -0700 Subject: [PATCH 19/28] Page --- 
.../java/org/mlflow/api/proto/Service.java | 681 ++++++++++++++---- mlflow/protos/service.proto | 6 + mlflow/protos/service_pb2.py | 62 +- .../artifact/databricks_artifact_repo.py | 41 +- 4 files changed, 625 insertions(+), 165 deletions(-) diff --git a/mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java b/mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java index 35deeef3cadd0..590c1baa405e8 100644 --- a/mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java +++ b/mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java @@ -37538,6 +37538,32 @@ public interface ListArtifactsOrBuilder extends */ com.google.protobuf.ByteString getPathBytes(); + + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + boolean hasPageToken(); + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + java.lang.String getPageToken(); + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + com.google.protobuf.ByteString + getPageTokenBytes(); } /** * Protobuf type {@code mlflow.ListArtifacts} @@ -37555,6 +37581,7 @@ private ListArtifacts() { runId_ = ""; runUuid_ = ""; path_ = ""; + pageToken_ = ""; } @java.lang.Override @@ -37599,6 +37626,12 @@ private ListArtifacts( runId_ = bs; break; } + case 34: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000008; + pageToken_ = bs; + break; + } default: { if (!parseUnknownField( input, unknownFields, extensionRegistry, tag)) { @@ -37704,6 +37737,32 @@ public interface ResponseOrBuilder extends */ org.mlflow.api.proto.Service.FileInfoOrBuilder getFilesOrBuilder( int index); + + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + boolean hasNextPageToken(); + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + java.lang.String getNextPageToken(); + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + com.google.protobuf.ByteString + getNextPageTokenBytes(); } /** * Protobuf type {@code mlflow.ListArtifacts.Response} @@ -37720,6 +37779,7 @@ private Response(com.google.protobuf.GeneratedMessageV3.Builder builder) { private Response() { rootUri_ = ""; files_ = java.util.Collections.emptyList(); + nextPageToken_ = ""; } @java.lang.Override @@ -37761,6 +37821,12 @@ private Response( input.readMessage(org.mlflow.api.proto.Service.FileInfo.PARSER, extensionRegistry)); break; } + case 26: { + com.google.protobuf.ByteString bs = input.readBytes(); + bitField0_ |= 0x00000002; + nextPageToken_ = bs; + break; + } default: { if (!parseUnknownField( input, unknownFields, extensionRegistry, tag)) { @@ -37906,6 +37972,60 @@ public org.mlflow.api.proto.Service.FileInfoOrBuilder getFilesOrBuilder( return files_.get(index); } + public static final int NEXT_PAGE_TOKEN_FIELD_NUMBER = 3; + private volatile java.lang.Object nextPageToken_; + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + public boolean hasNextPageToken() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + public java.lang.String getNextPageToken() { + java.lang.Object ref = nextPageToken_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + nextPageToken_ = s; + } + return s; + } + } + /** + *
+       * Token that can be used to retrieve the next page of artifact results
+       * 
+ * + * optional string next_page_token = 3; + */ + public com.google.protobuf.ByteString + getNextPageTokenBytes() { + java.lang.Object ref = nextPageToken_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + nextPageToken_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -37926,6 +38046,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) for (int i = 0; i < files_.size(); i++) { output.writeMessage(2, files_.get(i)); } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 3, nextPageToken_); + } unknownFields.writeTo(output); } @@ -37942,6 +38065,9 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(2, files_.get(i)); } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, nextPageToken_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -37965,6 +38091,11 @@ public boolean equals(final java.lang.Object obj) { } result = result && getFilesList() .equals(other.getFilesList()); + result = result && (hasNextPageToken() == other.hasNextPageToken()); + if (hasNextPageToken()) { + result = result && getNextPageToken() + .equals(other.getNextPageToken()); + } result = result && unknownFields.equals(other.unknownFields); return result; } @@ -37984,6 +38115,10 @@ public int hashCode() { hash = (37 * hash) + FILES_FIELD_NUMBER; hash = (53 * hash) + getFilesList().hashCode(); } + if (hasNextPageToken()) { + hash = (37 * hash) + NEXT_PAGE_TOKEN_FIELD_NUMBER; + hash = (53 * hash) + getNextPageToken().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -38126,6 
+38261,8 @@ public Builder clear() { } else { filesBuilder_.clear(); } + nextPageToken_ = ""; + bitField0_ = (bitField0_ & ~0x00000004); return this; } @@ -38167,6 +38304,10 @@ public org.mlflow.api.proto.Service.ListArtifacts.Response buildPartial() { } else { result.files_ = filesBuilder_.build(); } + if (((from_bitField0_ & 0x00000004) == 0x00000004)) { + to_bitField0_ |= 0x00000002; + } + result.nextPageToken_ = nextPageToken_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -38247,6 +38388,11 @@ public Builder mergeFrom(org.mlflow.api.proto.Service.ListArtifacts.Response oth } } } + if (other.hasNextPageToken()) { + bitField0_ |= 0x00000004; + nextPageToken_ = other.nextPageToken_; + onChanged(); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -38688,6 +38834,106 @@ public org.mlflow.api.proto.Service.FileInfo.Builder addFilesBuilder( } return filesBuilder_; } + + private java.lang.Object nextPageToken_ = ""; + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public boolean hasNextPageToken() { + return ((bitField0_ & 0x00000004) == 0x00000004); + } + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public java.lang.String getNextPageToken() { + java.lang.Object ref = nextPageToken_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + nextPageToken_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public com.google.protobuf.ByteString + getNextPageTokenBytes() { + java.lang.Object ref = nextPageToken_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + nextPageToken_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public Builder setNextPageToken( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000004; + nextPageToken_ = value; + onChanged(); + return this; + } + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public Builder clearNextPageToken() { + bitField0_ = (bitField0_ & ~0x00000004); + nextPageToken_ = getDefaultInstance().getNextPageToken(); + onChanged(); + return this; + } + /** + *
+         * Token that can be used to retrieve the next page of artifact results
+         * 
+ * + * optional string next_page_token = 3; + */ + public Builder setNextPageTokenBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000004; + nextPageToken_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -38907,6 +39153,60 @@ public java.lang.String getPath() { } } + public static final int PAGE_TOKEN_FIELD_NUMBER = 4; + private volatile java.lang.Object pageToken_; + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + public boolean hasPageToken() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + public java.lang.String getPageToken() { + java.lang.Object ref = pageToken_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + pageToken_ = s; + } + return s; + } + } + /** + *
+     * Token indicating the page of artifact results to fetch
+     * 
+ * + * optional string page_token = 4; + */ + public com.google.protobuf.ByteString + getPageTokenBytes() { + java.lang.Object ref = pageToken_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + pageToken_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -38930,6 +39230,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (((bitField0_ & 0x00000001) == 0x00000001)) { com.google.protobuf.GeneratedMessageV3.writeString(output, 3, runId_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 4, pageToken_); + } unknownFields.writeTo(output); } @@ -38948,6 +39251,9 @@ public int getSerializedSize() { if (((bitField0_ & 0x00000001) == 0x00000001)) { size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, runId_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(4, pageToken_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -38979,6 +39285,11 @@ public boolean equals(final java.lang.Object obj) { result = result && getPath() .equals(other.getPath()); } + result = result && (hasPageToken() == other.hasPageToken()); + if (hasPageToken()) { + result = result && getPageToken() + .equals(other.getPageToken()); + } result = result && unknownFields.equals(other.unknownFields); return result; } @@ -39002,6 +39313,10 @@ public int hashCode() { hash = (37 * hash) + PATH_FIELD_NUMBER; hash = (53 * hash) + getPath().hashCode(); } + if (hasPageToken()) { + hash = (37 * hash) + PAGE_TOKEN_FIELD_NUMBER; + hash = (53 * hash) + getPageToken().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ 
-39141,6 +39456,8 @@ public Builder clear() { bitField0_ = (bitField0_ & ~0x00000002); path_ = ""; bitField0_ = (bitField0_ & ~0x00000004); + pageToken_ = ""; + bitField0_ = (bitField0_ & ~0x00000008); return this; } @@ -39181,6 +39498,10 @@ public org.mlflow.api.proto.Service.ListArtifacts buildPartial() { to_bitField0_ |= 0x00000004; } result.path_ = path_; + if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + to_bitField0_ |= 0x00000008; + } + result.pageToken_ = pageToken_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -39245,6 +39566,11 @@ public Builder mergeFrom(org.mlflow.api.proto.Service.ListArtifacts other) { path_ = other.path_; onChanged(); } + if (other.hasPageToken()) { + bitField0_ |= 0x00000008; + pageToken_ = other.pageToken_; + onChanged(); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -39580,6 +39906,106 @@ public Builder setPathBytes( onChanged(); return this; } + + private java.lang.Object pageToken_ = ""; + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public boolean hasPageToken() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public java.lang.String getPageToken() { + java.lang.Object ref = pageToken_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + pageToken_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public com.google.protobuf.ByteString + getPageTokenBytes() { + java.lang.Object ref = pageToken_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + pageToken_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public Builder setPageToken( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + pageToken_ = value; + onChanged(); + return this; + } + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public Builder clearPageToken() { + bitField0_ = (bitField0_ & ~0x00000008); + pageToken_ = getDefaultInstance().getPageToken(); + onChanged(); + return this; + } + /** + *
+       * Token indicating the page of artifact results to fetch
+       * 
+ * + * optional string page_token = 4; + */ + public Builder setPageTokenBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + pageToken_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -47960,133 +48386,134 @@ public org.mlflow.api.proto.Service.GetExperimentByName getDefaultInstanceForTyp ":\0041000\022\020\n\010order_by\030\006 \003(\t\022\022\n\npage_token\030\007" + " \001(\t\032>\n\010Response\022\031\n\004runs\030\001 \003(\0132\013.mlflow." + "Run\022\027\n\017next_page_token\030\002 \001(\t:+\342?(\n&com.d" + - "atabricks.rpc.RPC[$this.Response]\"\253\001\n\rLi" + + "atabricks.rpc.RPC[$this.Response]\"\330\001\n\rLi" + "stArtifacts\022\016\n\006run_id\030\003 \001(\t\022\020\n\010run_uuid\030" + - "\001 \001(\t\022\014\n\004path\030\002 \001(\t\032=\n\010Response\022\020\n\010root_" + - "uri\030\001 \001(\t\022\037\n\005files\030\002 \003(\0132\020.mlflow.FileIn" + - "fo:+\342?(\n&com.databricks.rpc.RPC[$this.Re" + - "sponse]\";\n\010FileInfo\022\014\n\004path\030\001 \001(\t\022\016\n\006is_" + - "dir\030\002 \001(\010\022\021\n\tfile_size\030\003 \001(\003\"\250\001\n\020GetMetr" + - "icHistory\022\016\n\006run_id\030\003 \001(\t\022\020\n\010run_uuid\030\001 " + - "\001(\t\022\030\n\nmetric_key\030\002 \001(\tB\004\370\206\031\001\032+\n\010Respons" + - "e\022\037\n\007metrics\030\001 \003(\0132\016.mlflow.Metric:+\342?(\n" + - "&com.databricks.rpc.RPC[$this.Response]\"" + - "\261\001\n\010LogBatch\022\016\n\006run_id\030\001 \001(\t\022\037\n\007metrics\030" + - "\002 \003(\0132\016.mlflow.Metric\022\035\n\006params\030\003 \003(\0132\r." + - "mlflow.Param\022\034\n\004tags\030\004 \003(\0132\016.mlflow.RunT" + - "ag\032\n\n\010Response:+\342?(\n&com.databricks.rpc." 
+ - "RPC[$this.Response]\"g\n\010LogModel\022\016\n\006run_i" + - "d\030\001 \001(\t\022\022\n\nmodel_json\030\002 \001(\t\032\n\n\010Response:" + + "\001 \001(\t\022\014\n\004path\030\002 \001(\t\022\022\n\npage_token\030\004 \001(\t\032" + + "V\n\010Response\022\020\n\010root_uri\030\001 \001(\t\022\037\n\005files\030\002" + + " \003(\0132\020.mlflow.FileInfo\022\027\n\017next_page_toke" + + "n\030\003 \001(\t:+\342?(\n&com.databricks.rpc.RPC[$th" + + "is.Response]\";\n\010FileInfo\022\014\n\004path\030\001 \001(\t\022\016" + + "\n\006is_dir\030\002 \001(\010\022\021\n\tfile_size\030\003 \001(\003\"\250\001\n\020Ge" + + "tMetricHistory\022\016\n\006run_id\030\003 \001(\t\022\020\n\010run_uu" + + "id\030\001 \001(\t\022\030\n\nmetric_key\030\002 \001(\tB\004\370\206\031\001\032+\n\010Re" + + "sponse\022\037\n\007metrics\030\001 \003(\0132\016.mlflow.Metric:" + "+\342?(\n&com.databricks.rpc.RPC[$this.Respo" + - "nse]\"\225\001\n\023GetExperimentByName\022\035\n\017experime" + - "nt_name\030\001 \001(\tB\004\370\206\031\001\0322\n\010Response\022&\n\nexper" + - "iment\030\001 \001(\0132\022.mlflow.Experiment:+\342?(\n&co" + - "m.databricks.rpc.RPC[$this.Response]*6\n\010" + - "ViewType\022\017\n\013ACTIVE_ONLY\020\001\022\020\n\014DELETED_ONL" + - "Y\020\002\022\007\n\003ALL\020\003*I\n\nSourceType\022\014\n\010NOTEBOOK\020\001" + - "\022\007\n\003JOB\020\002\022\013\n\007PROJECT\020\003\022\t\n\005LOCAL\020\004\022\014\n\007UNK" + - "NOWN\020\350\007*M\n\tRunStatus\022\013\n\007RUNNING\020\001\022\r\n\tSCH" + - "EDULED\020\002\022\014\n\010FINISHED\020\003\022\n\n\006FAILED\020\004\022\n\n\006KI" + - "LLED\020\0052\341\036\n\rMlflowService\022\246\001\n\023getExperime" + - "ntByName\022\033.mlflow.GetExperimentByName\032$." 
+ - "mlflow.GetExperimentByName.Response\"L\362\206\031" + - "H\n,\n\003GET\022\037/mlflow/experiments/get-by-nam" + - "e\032\004\010\002\020\000\020\001*\026Get Experiment By Name\022\306\001\n\020cr" + - "eateExperiment\022\030.mlflow.CreateExperiment" + - "\032!.mlflow.CreateExperiment.Response\"u\362\206\031" + - "q\n(\n\004POST\022\032/mlflow/experiments/create\032\004\010" + - "\002\020\000\n0\n\004POST\022\"/preview/mlflow/experiments" + - "/create\032\004\010\002\020\000\020\001*\021Create Experiment\022\274\001\n\017l" + - "istExperiments\022\027.mlflow.ListExperiments\032" + - " .mlflow.ListExperiments.Response\"n\362\206\031j\n" + - "%\n\003GET\022\030/mlflow/experiments/list\032\004\010\002\020\000\n-" + - "\n\003GET\022 /preview/mlflow/experiments/list\032" + - "\004\010\002\020\000\020\001*\020List Experiments\022\262\001\n\rgetExperim" + - "ent\022\025.mlflow.GetExperiment\032\036.mlflow.GetE" + - "xperiment.Response\"j\362\206\031f\n$\n\003GET\022\027/mlflow" + - "/experiments/get\032\004\010\002\020\000\n,\n\003GET\022\037/preview/" + - "mlflow/experiments/get\032\004\010\002\020\000\020\001*\016Get Expe" + - "riment\022\306\001\n\020deleteExperiment\022\030.mlflow.Del" + - "eteExperiment\032!.mlflow.DeleteExperiment." 
+ - "Response\"u\362\206\031q\n(\n\004POST\022\032/mlflow/experime" + - "nts/delete\032\004\010\002\020\000\n0\n\004POST\022\"/preview/mlflo" + - "w/experiments/delete\032\004\010\002\020\000\020\001*\021Delete Exp" + - "eriment\022\314\001\n\021restoreExperiment\022\031.mlflow.R" + - "estoreExperiment\032\".mlflow.RestoreExperim" + - "ent.Response\"x\362\206\031t\n)\n\004POST\022\033/mlflow/expe" + - "riments/restore\032\004\010\002\020\000\n1\n\004POST\022#/preview/" + - "mlflow/experiments/restore\032\004\010\002\020\000\020\001*\022Rest" + - "ore Experiment\022\306\001\n\020updateExperiment\022\030.ml" + - "flow.UpdateExperiment\032!.mlflow.UpdateExp" + - "eriment.Response\"u\362\206\031q\n(\n\004POST\022\032/mlflow/" + - "experiments/update\032\004\010\002\020\000\n0\n\004POST\022\"/previ" + - "ew/mlflow/experiments/update\032\004\010\002\020\000\020\001*\021Up" + - "date Experiment\022\234\001\n\tcreateRun\022\021.mlflow.C" + - "reateRun\032\032.mlflow.CreateRun.Response\"`\362\206" + - "\031\\\n!\n\004POST\022\023/mlflow/runs/create\032\004\010\002\020\000\n)\n" + - "\004POST\022\033/preview/mlflow/runs/create\032\004\010\002\020\000" + - "\020\001*\nCreate Run\022\234\001\n\tupdateRun\022\021.mlflow.Up" + - "dateRun\032\032.mlflow.UpdateRun.Response\"`\362\206\031" + - "\\\n!\n\004POST\022\023/mlflow/runs/update\032\004\010\002\020\000\n)\n\004" + - "POST\022\033/preview/mlflow/runs/update\032\004\010\002\020\000\020" + - "\001*\nUpdate Run\022\234\001\n\tdeleteRun\022\021.mlflow.Del" + - "eteRun\032\032.mlflow.DeleteRun.Response\"`\362\206\031\\" + - "\n!\n\004POST\022\023/mlflow/runs/delete\032\004\010\002\020\000\n)\n\004P" + - "OST\022\033/preview/mlflow/runs/delete\032\004\010\002\020\000\020\001" + - "*\nDelete Run\022\242\001\n\nrestoreRun\022\022.mlflow.Res" + - "toreRun\032\033.mlflow.RestoreRun.Response\"c\362\206" + - "\031_\n\"\n\004POST\022\024/mlflow/runs/restore\032\004\010\002\020\000\n*" + - 
"\n\004POST\022\034/preview/mlflow/runs/restore\032\004\010\002" + - "\020\000\020\001*\013Restore Run\022\244\001\n\tlogMetric\022\021.mlflow" + - ".LogMetric\032\032.mlflow.LogMetric.Response\"h" + - "\362\206\031d\n%\n\004POST\022\027/mlflow/runs/log-metric\032\004\010" + - "\002\020\000\n-\n\004POST\022\037/preview/mlflow/runs/log-me" + - "tric\032\004\010\002\020\000\020\001*\nLog Metric\022\246\001\n\010logParam\022\020." + - "mlflow.LogParam\032\031.mlflow.LogParam.Respon" + - "se\"m\362\206\031i\n(\n\004POST\022\032/mlflow/runs/log-param" + - "eter\032\004\010\002\020\000\n0\n\004POST\022\"/preview/mlflow/runs" + - "/log-parameter\032\004\010\002\020\000\020\001*\tLog Param\022\341\001\n\020se" + - "tExperimentTag\022\030.mlflow.SetExperimentTag" + - "\032!.mlflow.SetExperimentTag.Response\"\217\001\362\206" + - "\031\212\001\n4\n\004POST\022&/mlflow/experiments/set-exp" + - "eriment-tag\032\004\010\002\020\000\n<\n\004POST\022./preview/mlfl" + - "ow/experiments/set-experiment-tag\032\004\010\002\020\000\020" + - "\001*\022Set Experiment Tag\022\222\001\n\006setTag\022\016.mlflo" + - "w.SetTag\032\027.mlflow.SetTag.Response\"_\362\206\031[\n" + - "\"\n\004POST\022\024/mlflow/runs/set-tag\032\004\010\002\020\000\n*\n\004P" + - "OST\022\034/preview/mlflow/runs/set-tag\032\004\010\002\020\000\020" + - "\001*\007Set Tag\022\244\001\n\tdeleteTag\022\021.mlflow.Delete" + - "Tag\032\032.mlflow.DeleteTag.Response\"h\362\206\031d\n%\n" + - "\004POST\022\027/mlflow/runs/delete-tag\032\004\010\002\020\000\n-\n\004" + - "POST\022\037/preview/mlflow/runs/delete-tag\032\004\010" + - "\002\020\000\020\001*\nDelete Tag\022\210\001\n\006getRun\022\016.mlflow.Ge" + - "tRun\032\027.mlflow.GetRun.Response\"U\362\206\031Q\n\035\n\003G" + - "ET\022\020/mlflow/runs/get\032\004\010\002\020\000\n%\n\003GET\022\030/prev" + - "iew/mlflow/runs/get\032\004\010\002\020\000\020\001*\007Get Run\022\314\001\n" + - "\nsearchRuns\022\022.mlflow.SearchRuns\032\033.mlflow" + - 
".SearchRuns.Response\"\214\001\362\206\031\207\001\n!\n\004POST\022\023/m" + - "lflow/runs/search\032\004\010\002\020\000\n)\n\004POST\022\033/previe" + - "w/mlflow/runs/search\032\004\010\002\020\000\n(\n\003GET\022\033/prev" + - "iew/mlflow/runs/search\032\004\010\002\020\000\020\001*\013Search R" + - "uns\022\260\001\n\rlistArtifacts\022\025.mlflow.ListArtif" + - "acts\032\036.mlflow.ListArtifacts.Response\"h\362\206" + - "\031d\n#\n\003GET\022\026/mlflow/artifacts/list\032\004\010\002\020\000\n" + - "+\n\003GET\022\036/preview/mlflow/artifacts/list\032\004" + - "\010\002\020\000\020\001*\016List Artifacts\022\307\001\n\020getMetricHist" + - "ory\022\030.mlflow.GetMetricHistory\032!.mlflow.G" + - "etMetricHistory.Response\"v\362\206\031r\n(\n\003GET\022\033/" + - "mlflow/metrics/get-history\032\004\010\002\020\000\n0\n\003GET\022" + - "#/preview/mlflow/metrics/get-history\032\004\010\002" + - "\020\000\020\001*\022Get Metric History\022\236\001\n\010logBatch\022\020." + - "mlflow.LogBatch\032\031.mlflow.LogBatch.Respon" + - "se\"e\362\206\031a\n$\n\004POST\022\026/mlflow/runs/log-batch" + - "\032\004\010\002\020\000\n,\n\004POST\022\036/preview/mlflow/runs/log" + - "-batch\032\004\010\002\020\000\020\001*\tLog Batch\022\236\001\n\010logModel\022\020" + - ".mlflow.LogModel\032\031.mlflow.LogModel.Respo" + - "nse\"e\362\206\031a\n$\n\004POST\022\026/mlflow/runs/log-mode" + - "l\032\004\010\002\020\000\n,\n\004POST\022\036/preview/mlflow/runs/lo" + - "g-model\032\004\010\002\020\000\020\001*\tLog ModelB\036\n\024org.mlflow" + - ".api.proto\220\001\001\342?\002\020\001" + "nse]\"\261\001\n\010LogBatch\022\016\n\006run_id\030\001 \001(\t\022\037\n\007met" + + "rics\030\002 \003(\0132\016.mlflow.Metric\022\035\n\006params\030\003 \003" + + "(\0132\r.mlflow.Param\022\034\n\004tags\030\004 \003(\0132\016.mlflow" + + ".RunTag\032\n\n\010Response:+\342?(\n&com.databricks" + + ".rpc.RPC[$this.Response]\"g\n\010LogModel\022\016\n\006" + + "run_id\030\001 
\001(\t\022\022\n\nmodel_json\030\002 \001(\t\032\n\n\010Resp" + + "onse:+\342?(\n&com.databricks.rpc.RPC[$this." + + "Response]\"\225\001\n\023GetExperimentByName\022\035\n\017exp" + + "eriment_name\030\001 \001(\tB\004\370\206\031\001\0322\n\010Response\022&\n\n" + + "experiment\030\001 \001(\0132\022.mlflow.Experiment:+\342?" + + "(\n&com.databricks.rpc.RPC[$this.Response" + + "]*6\n\010ViewType\022\017\n\013ACTIVE_ONLY\020\001\022\020\n\014DELETE" + + "D_ONLY\020\002\022\007\n\003ALL\020\003*I\n\nSourceType\022\014\n\010NOTEB" + + "OOK\020\001\022\007\n\003JOB\020\002\022\013\n\007PROJECT\020\003\022\t\n\005LOCAL\020\004\022\014" + + "\n\007UNKNOWN\020\350\007*M\n\tRunStatus\022\013\n\007RUNNING\020\001\022\r" + + "\n\tSCHEDULED\020\002\022\014\n\010FINISHED\020\003\022\n\n\006FAILED\020\004\022" + + "\n\n\006KILLED\020\0052\341\036\n\rMlflowService\022\246\001\n\023getExp" + + "erimentByName\022\033.mlflow.GetExperimentByNa" + + "me\032$.mlflow.GetExperimentByName.Response" + + "\"L\362\206\031H\n,\n\003GET\022\037/mlflow/experiments/get-b" + + "y-name\032\004\010\002\020\000\020\001*\026Get Experiment By Name\022\306" + + "\001\n\020createExperiment\022\030.mlflow.CreateExper" + + "iment\032!.mlflow.CreateExperiment.Response" + + "\"u\362\206\031q\n(\n\004POST\022\032/mlflow/experiments/crea" + + "te\032\004\010\002\020\000\n0\n\004POST\022\"/preview/mlflow/experi" + + "ments/create\032\004\010\002\020\000\020\001*\021Create Experiment\022" + + "\274\001\n\017listExperiments\022\027.mlflow.ListExperim" + + "ents\032 .mlflow.ListExperiments.Response\"n" + + "\362\206\031j\n%\n\003GET\022\030/mlflow/experiments/list\032\004\010" + + "\002\020\000\n-\n\003GET\022 /preview/mlflow/experiments/" + + "list\032\004\010\002\020\000\020\001*\020List Experiments\022\262\001\n\rgetEx" + + "periment\022\025.mlflow.GetExperiment\032\036.mlflow" + + ".GetExperiment.Response\"j\362\206\031f\n$\n\003GET\022\027/m" + + 
"lflow/experiments/get\032\004\010\002\020\000\n,\n\003GET\022\037/pre" + + "view/mlflow/experiments/get\032\004\010\002\020\000\020\001*\016Get" + + " Experiment\022\306\001\n\020deleteExperiment\022\030.mlflo" + + "w.DeleteExperiment\032!.mlflow.DeleteExperi" + + "ment.Response\"u\362\206\031q\n(\n\004POST\022\032/mlflow/exp" + + "eriments/delete\032\004\010\002\020\000\n0\n\004POST\022\"/preview/" + + "mlflow/experiments/delete\032\004\010\002\020\000\020\001*\021Delet" + + "e Experiment\022\314\001\n\021restoreExperiment\022\031.mlf" + + "low.RestoreExperiment\032\".mlflow.RestoreEx" + + "periment.Response\"x\362\206\031t\n)\n\004POST\022\033/mlflow" + + "/experiments/restore\032\004\010\002\020\000\n1\n\004POST\022#/pre" + + "view/mlflow/experiments/restore\032\004\010\002\020\000\020\001*" + + "\022Restore Experiment\022\306\001\n\020updateExperiment" + + "\022\030.mlflow.UpdateExperiment\032!.mlflow.Upda" + + "teExperiment.Response\"u\362\206\031q\n(\n\004POST\022\032/ml" + + "flow/experiments/update\032\004\010\002\020\000\n0\n\004POST\022\"/" + + "preview/mlflow/experiments/update\032\004\010\002\020\000\020" + + "\001*\021Update Experiment\022\234\001\n\tcreateRun\022\021.mlf" + + "low.CreateRun\032\032.mlflow.CreateRun.Respons" + + "e\"`\362\206\031\\\n!\n\004POST\022\023/mlflow/runs/create\032\004\010\002" + + "\020\000\n)\n\004POST\022\033/preview/mlflow/runs/create\032" + + "\004\010\002\020\000\020\001*\nCreate Run\022\234\001\n\tupdateRun\022\021.mlfl" + + "ow.UpdateRun\032\032.mlflow.UpdateRun.Response" + + "\"`\362\206\031\\\n!\n\004POST\022\023/mlflow/runs/update\032\004\010\002\020" + + "\000\n)\n\004POST\022\033/preview/mlflow/runs/update\032\004" + + "\010\002\020\000\020\001*\nUpdate Run\022\234\001\n\tdeleteRun\022\021.mlflo" + + "w.DeleteRun\032\032.mlflow.DeleteRun.Response\"" + + "`\362\206\031\\\n!\n\004POST\022\023/mlflow/runs/delete\032\004\010\002\020\000" + + "\n)\n\004POST\022\033/preview/mlflow/runs/delete\032\004\010" + + 
"\002\020\000\020\001*\nDelete Run\022\242\001\n\nrestoreRun\022\022.mlflo" + + "w.RestoreRun\032\033.mlflow.RestoreRun.Respons" + + "e\"c\362\206\031_\n\"\n\004POST\022\024/mlflow/runs/restore\032\004\010" + + "\002\020\000\n*\n\004POST\022\034/preview/mlflow/runs/restor" + + "e\032\004\010\002\020\000\020\001*\013Restore Run\022\244\001\n\tlogMetric\022\021.m" + + "lflow.LogMetric\032\032.mlflow.LogMetric.Respo" + + "nse\"h\362\206\031d\n%\n\004POST\022\027/mlflow/runs/log-metr" + + "ic\032\004\010\002\020\000\n-\n\004POST\022\037/preview/mlflow/runs/l" + + "og-metric\032\004\010\002\020\000\020\001*\nLog Metric\022\246\001\n\010logPar" + + "am\022\020.mlflow.LogParam\032\031.mlflow.LogParam.R" + + "esponse\"m\362\206\031i\n(\n\004POST\022\032/mlflow/runs/log-" + + "parameter\032\004\010\002\020\000\n0\n\004POST\022\"/preview/mlflow" + + "/runs/log-parameter\032\004\010\002\020\000\020\001*\tLog Param\022\341" + + "\001\n\020setExperimentTag\022\030.mlflow.SetExperime" + + "ntTag\032!.mlflow.SetExperimentTag.Response" + + "\"\217\001\362\206\031\212\001\n4\n\004POST\022&/mlflow/experiments/se" + + "t-experiment-tag\032\004\010\002\020\000\n<\n\004POST\022./preview" + + "/mlflow/experiments/set-experiment-tag\032\004" + + "\010\002\020\000\020\001*\022Set Experiment Tag\022\222\001\n\006setTag\022\016." 
+ + "mlflow.SetTag\032\027.mlflow.SetTag.Response\"_" + + "\362\206\031[\n\"\n\004POST\022\024/mlflow/runs/set-tag\032\004\010\002\020\000" + + "\n*\n\004POST\022\034/preview/mlflow/runs/set-tag\032\004" + + "\010\002\020\000\020\001*\007Set Tag\022\244\001\n\tdeleteTag\022\021.mlflow.D" + + "eleteTag\032\032.mlflow.DeleteTag.Response\"h\362\206" + + "\031d\n%\n\004POST\022\027/mlflow/runs/delete-tag\032\004\010\002\020" + + "\000\n-\n\004POST\022\037/preview/mlflow/runs/delete-t" + + "ag\032\004\010\002\020\000\020\001*\nDelete Tag\022\210\001\n\006getRun\022\016.mlfl" + + "ow.GetRun\032\027.mlflow.GetRun.Response\"U\362\206\031Q" + + "\n\035\n\003GET\022\020/mlflow/runs/get\032\004\010\002\020\000\n%\n\003GET\022\030" + + "/preview/mlflow/runs/get\032\004\010\002\020\000\020\001*\007Get Ru" + + "n\022\314\001\n\nsearchRuns\022\022.mlflow.SearchRuns\032\033.m" + + "lflow.SearchRuns.Response\"\214\001\362\206\031\207\001\n!\n\004POS" + + "T\022\023/mlflow/runs/search\032\004\010\002\020\000\n)\n\004POST\022\033/p" + + "review/mlflow/runs/search\032\004\010\002\020\000\n(\n\003GET\022\033" + + "/preview/mlflow/runs/search\032\004\010\002\020\000\020\001*\013Sea" + + "rch Runs\022\260\001\n\rlistArtifacts\022\025.mlflow.List" + + "Artifacts\032\036.mlflow.ListArtifacts.Respons" + + "e\"h\362\206\031d\n#\n\003GET\022\026/mlflow/artifacts/list\032\004" + + "\010\002\020\000\n+\n\003GET\022\036/preview/mlflow/artifacts/l" + + "ist\032\004\010\002\020\000\020\001*\016List Artifacts\022\307\001\n\020getMetri" + + "cHistory\022\030.mlflow.GetMetricHistory\032!.mlf" + + "low.GetMetricHistory.Response\"v\362\206\031r\n(\n\003G" + + "ET\022\033/mlflow/metrics/get-history\032\004\010\002\020\000\n0\n" + + "\003GET\022#/preview/mlflow/metrics/get-histor" + + "y\032\004\010\002\020\000\020\001*\022Get Metric History\022\236\001\n\010logBat" + + "ch\022\020.mlflow.LogBatch\032\031.mlflow.LogBatch.R" + + "esponse\"e\362\206\031a\n$\n\004POST\022\026/mlflow/runs/log-" + + 
"batch\032\004\010\002\020\000\n,\n\004POST\022\036/preview/mlflow/run" + + "s/log-batch\032\004\010\002\020\000\020\001*\tLog Batch\022\236\001\n\010logMo" + + "del\022\020.mlflow.LogModel\032\031.mlflow.LogModel." + + "Response\"e\362\206\031a\n$\n\004POST\022\026/mlflow/runs/log" + + "-model\032\004\010\002\020\000\n,\n\004POST\022\036/preview/mlflow/ru" + + "ns/log-model\032\004\010\002\020\000\020\001*\tLog ModelB\036\n\024org.m" + + "lflow.api.proto\220\001\001\342?\002\020\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -48359,13 +48786,13 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_mlflow_ListArtifacts_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_mlflow_ListArtifacts_descriptor, - new java.lang.String[] { "RunId", "RunUuid", "Path", }); + new java.lang.String[] { "RunId", "RunUuid", "Path", "PageToken", }); internal_static_mlflow_ListArtifacts_Response_descriptor = internal_static_mlflow_ListArtifacts_descriptor.getNestedTypes().get(0); internal_static_mlflow_ListArtifacts_Response_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_mlflow_ListArtifacts_Response_descriptor, - new java.lang.String[] { "RootUri", "Files", }); + new java.lang.String[] { "RootUri", "Files", "NextPageToken", }); internal_static_mlflow_FileInfo_descriptor = getDescriptor().getMessageTypes().get(26); internal_static_mlflow_FileInfo_fieldAccessorTable = new diff --git a/mlflow/protos/service.proto b/mlflow/protos/service.proto index 335f8304957c6..1cdb78563a5f7 100644 --- a/mlflow/protos/service.proto +++ b/mlflow/protos/service.proto @@ -972,12 +972,18 @@ message ListArtifacts { // Filter artifacts matching this path (a relative path from the root artifact directory). 
optional string path = 2; + // Token indicating the page of artifact results to fetch + optional string page_token = 4; + message Response { // Root artifact directory for the run. optional string root_uri = 1; // File location and metadata for artifacts. repeated FileInfo files = 2; + + // Token that can be used to retrieve the next page of artifact results + optional string next_page_token = 3; } } diff --git a/mlflow/protos/service_pb2.py b/mlflow/protos/service_pb2.py index 991bd1c852f03..f39110ef0266d 100644 --- a/mlflow/protos/service_pb2.py +++ b/mlflow/protos/service_pb2.py @@ -24,7 +24,7 @@ package='mlflow', syntax='proto2', serialized_options=_b('\n\024org.mlflow.api.proto\220\001\001\342?\002\020\001'), - serialized_pb=_b('\n\rservice.proto\x12\x06mlflow\x1a\x15scalapb/scalapb.proto\x1a\x10\x64\x61tabricks.proto\"H\n\x06Metric\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x01\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x04step\x18\x04 \x01(\x03:\x01\x30\"#\n\x05Param\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"C\n\x03Run\x12\x1d\n\x04info\x18\x01 \x01(\x0b\x32\x0f.mlflow.RunInfo\x12\x1d\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x0f.mlflow.RunData\"g\n\x07RunData\x12\x1f\n\x07metrics\x18\x01 \x03(\x0b\x32\x0e.mlflow.Metric\x12\x1d\n\x06params\x18\x02 \x03(\x0b\x32\r.mlflow.Param\x12\x1c\n\x04tags\x18\x03 \x03(\x0b\x32\x0e.mlflow.RunTag\"$\n\x06RunTag\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"+\n\rExperimentTag\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\xcb\x01\n\x07RunInfo\x12\x0e\n\x06run_id\x18\x0f \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x15\n\rexperiment_id\x18\x02 \x01(\t\x12\x0f\n\x07user_id\x18\x06 \x01(\t\x12!\n\x06status\x18\x07 \x01(\x0e\x32\x11.mlflow.RunStatus\x12\x12\n\nstart_time\x18\x08 \x01(\x03\x12\x10\n\x08\x65nd_time\x18\t \x01(\x03\x12\x14\n\x0c\x61rtifact_uri\x18\r \x01(\t\x12\x17\n\x0flifecycle_stage\x18\x0e 
\x01(\t\"\xbb\x01\n\nExperiment\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x19\n\x11\x61rtifact_location\x18\x03 \x01(\t\x12\x17\n\x0flifecycle_stage\x18\x04 \x01(\t\x12\x18\n\x10last_update_time\x18\x05 \x01(\x03\x12\x15\n\rcreation_time\x18\x06 \x01(\x03\x12#\n\x04tags\x18\x07 \x03(\x0b\x32\x15.mlflow.ExperimentTag\"\x91\x01\n\x10\x43reateExperiment\x12\x12\n\x04name\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x19\n\x11\x61rtifact_location\x18\x02 \x01(\t\x1a!\n\x08Response\x12\x15\n\rexperiment_id\x18\x01 \x01(\t:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x98\x01\n\x0fListExperiments\x12#\n\tview_type\x18\x01 \x01(\x0e\x32\x10.mlflow.ViewType\x1a\x33\n\x08Response\x12\'\n\x0b\x65xperiments\x18\x01 \x03(\x0b\x32\x12.mlflow.Experiment:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb0\x01\n\rGetExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1aU\n\x08Response\x12&\n\nexperiment\x18\x01 \x01(\x0b\x32\x12.mlflow.Experiment\x12!\n\x04runs\x18\x02 \x03(\x0b\x32\x0f.mlflow.RunInfoB\x02\x18\x01:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"h\n\x10\x44\x65leteExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"i\n\x11RestoreExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"z\n\x10UpdateExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x10\n\x08new_name\x18\x02 \x01(\t\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb8\x01\n\tCreateRun\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\t\x12\x12\n\nstart_time\x18\x07 \x01(\x03\x12\x1c\n\x04tags\x18\t \x03(\x0b\x32\x0e.mlflow.RunTag\x1a$\n\x08Response\x12\x18\n\x03run\x18\x01 
\x01(\x0b\x32\x0b.mlflow.Run:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xbe\x01\n\tUpdateRun\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0e\x32\x11.mlflow.RunStatus\x12\x10\n\x08\x65nd_time\x18\x03 \x01(\x03\x1a-\n\x08Response\x12!\n\x08run_info\x18\x01 \x01(\x0b\x32\x0f.mlflow.RunInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"Z\n\tDeleteRun\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"[\n\nRestoreRun\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb8\x01\n\tLogMetric\x12\x0e\n\x06run_id\x18\x06 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\x01\x42\x04\xf8\x86\x19\x01\x12\x17\n\ttimestamp\x18\x04 \x01(\x03\x42\x04\xf8\x86\x19\x01\x12\x0f\n\x04step\x18\x05 \x01(\x03:\x01\x30\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x8d\x01\n\x08LogParam\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x90\x01\n\x10SetExperimentTag\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x8b\x01\n\x06SetTag\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"m\n\tDeleteTag\x12\x14\n\x06run_id\x18\x01 
\x01(\tB\x04\xf8\x86\x19\x01\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"}\n\x06GetRun\x12\x0e\n\x06run_id\x18\x02 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x1a$\n\x08Response\x12\x18\n\x03run\x18\x01 \x01(\x0b\x32\x0b.mlflow.Run:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x98\x02\n\nSearchRuns\x12\x16\n\x0e\x65xperiment_ids\x18\x01 \x03(\t\x12\x0e\n\x06\x66ilter\x18\x04 \x01(\t\x12\x34\n\rrun_view_type\x18\x03 \x01(\x0e\x32\x10.mlflow.ViewType:\x0b\x41\x43TIVE_ONLY\x12\x19\n\x0bmax_results\x18\x05 \x01(\x05:\x04\x31\x30\x30\x30\x12\x10\n\x08order_by\x18\x06 \x03(\t\x12\x12\n\npage_token\x18\x07 \x01(\t\x1a>\n\x08Response\x12\x19\n\x04runs\x18\x01 \x03(\x0b\x32\x0b.mlflow.Run\x12\x17\n\x0fnext_page_token\x18\x02 \x01(\t:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xab\x01\n\rListArtifacts\x12\x0e\n\x06run_id\x18\x03 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x1a=\n\x08Response\x12\x10\n\x08root_uri\x18\x01 \x01(\t\x12\x1f\n\x05\x66iles\x18\x02 \x03(\x0b\x32\x10.mlflow.FileInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\";\n\x08\x46ileInfo\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x0e\n\x06is_dir\x18\x02 \x01(\x08\x12\x11\n\tfile_size\x18\x03 \x01(\x03\"\xa8\x01\n\x10GetMetricHistory\x12\x0e\n\x06run_id\x18\x03 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x18\n\nmetric_key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a+\n\x08Response\x12\x1f\n\x07metrics\x18\x01 \x03(\x0b\x32\x0e.mlflow.Metric:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb1\x01\n\x08LogBatch\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x1f\n\x07metrics\x18\x02 \x03(\x0b\x32\x0e.mlflow.Metric\x12\x1d\n\x06params\x18\x03 \x03(\x0b\x32\r.mlflow.Param\x12\x1c\n\x04tags\x18\x04 \x03(\x0b\x32\x0e.mlflow.RunTag\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"g\n\x08LogModel\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x12\n\nmodel_json\x18\x02 
\x01(\t\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x95\x01\n\x13GetExperimentByName\x12\x1d\n\x0f\x65xperiment_name\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\x32\n\x08Response\x12&\n\nexperiment\x18\x01 \x01(\x0b\x32\x12.mlflow.Experiment:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]*6\n\x08ViewType\x12\x0f\n\x0b\x41\x43TIVE_ONLY\x10\x01\x12\x10\n\x0c\x44\x45LETED_ONLY\x10\x02\x12\x07\n\x03\x41LL\x10\x03*I\n\nSourceType\x12\x0c\n\x08NOTEBOOK\x10\x01\x12\x07\n\x03JOB\x10\x02\x12\x0b\n\x07PROJECT\x10\x03\x12\t\n\x05LOCAL\x10\x04\x12\x0c\n\x07UNKNOWN\x10\xe8\x07*M\n\tRunStatus\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0c\n\x08\x46INISHED\x10\x03\x12\n\n\x06\x46\x41ILED\x10\x04\x12\n\n\x06KILLED\x10\x05\x32\xe1\x1e\n\rMlflowService\x12\xa6\x01\n\x13getExperimentByName\x12\x1b.mlflow.GetExperimentByName\x1a$.mlflow.GetExperimentByName.Response\"L\xf2\x86\x19H\n,\n\x03GET\x12\x1f/mlflow/experiments/get-by-name\x1a\x04\x08\x02\x10\x00\x10\x01*\x16Get Experiment By Name\x12\xc6\x01\n\x10\x63reateExperiment\x12\x18.mlflow.CreateExperiment\x1a!.mlflow.CreateExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/create\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/create\x1a\x04\x08\x02\x10\x00\x10\x01*\x11\x43reate Experiment\x12\xbc\x01\n\x0flistExperiments\x12\x17.mlflow.ListExperiments\x1a .mlflow.ListExperiments.Response\"n\xf2\x86\x19j\n%\n\x03GET\x12\x18/mlflow/experiments/list\x1a\x04\x08\x02\x10\x00\n-\n\x03GET\x12 /preview/mlflow/experiments/list\x1a\x04\x08\x02\x10\x00\x10\x01*\x10List Experiments\x12\xb2\x01\n\rgetExperiment\x12\x15.mlflow.GetExperiment\x1a\x1e.mlflow.GetExperiment.Response\"j\xf2\x86\x19\x66\n$\n\x03GET\x12\x17/mlflow/experiments/get\x1a\x04\x08\x02\x10\x00\n,\n\x03GET\x12\x1f/preview/mlflow/experiments/get\x1a\x04\x08\x02\x10\x00\x10\x01*\x0eGet 
Experiment\x12\xc6\x01\n\x10\x64\x65leteExperiment\x12\x18.mlflow.DeleteExperiment\x1a!.mlflow.DeleteExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/delete\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/delete\x1a\x04\x08\x02\x10\x00\x10\x01*\x11\x44\x65lete Experiment\x12\xcc\x01\n\x11restoreExperiment\x12\x19.mlflow.RestoreExperiment\x1a\".mlflow.RestoreExperiment.Response\"x\xf2\x86\x19t\n)\n\x04POST\x12\x1b/mlflow/experiments/restore\x1a\x04\x08\x02\x10\x00\n1\n\x04POST\x12#/preview/mlflow/experiments/restore\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Restore Experiment\x12\xc6\x01\n\x10updateExperiment\x12\x18.mlflow.UpdateExperiment\x1a!.mlflow.UpdateExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/update\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/update\x1a\x04\x08\x02\x10\x00\x10\x01*\x11Update Experiment\x12\x9c\x01\n\tcreateRun\x12\x11.mlflow.CreateRun\x1a\x1a.mlflow.CreateRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/create\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/create\x1a\x04\x08\x02\x10\x00\x10\x01*\nCreate Run\x12\x9c\x01\n\tupdateRun\x12\x11.mlflow.UpdateRun\x1a\x1a.mlflow.UpdateRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/update\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/update\x1a\x04\x08\x02\x10\x00\x10\x01*\nUpdate Run\x12\x9c\x01\n\tdeleteRun\x12\x11.mlflow.DeleteRun\x1a\x1a.mlflow.DeleteRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/delete\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/delete\x1a\x04\x08\x02\x10\x00\x10\x01*\nDelete Run\x12\xa2\x01\n\nrestoreRun\x12\x12.mlflow.RestoreRun\x1a\x1b.mlflow.RestoreRun.Response\"c\xf2\x86\x19_\n\"\n\x04POST\x12\x14/mlflow/runs/restore\x1a\x04\x08\x02\x10\x00\n*\n\x04POST\x12\x1c/preview/mlflow/runs/restore\x1a\x04\x08\x02\x10\x00\x10\x01*\x0bRestore 
Run\x12\xa4\x01\n\tlogMetric\x12\x11.mlflow.LogMetric\x1a\x1a.mlflow.LogMetric.Response\"h\xf2\x86\x19\x64\n%\n\x04POST\x12\x17/mlflow/runs/log-metric\x1a\x04\x08\x02\x10\x00\n-\n\x04POST\x12\x1f/preview/mlflow/runs/log-metric\x1a\x04\x08\x02\x10\x00\x10\x01*\nLog Metric\x12\xa6\x01\n\x08logParam\x12\x10.mlflow.LogParam\x1a\x19.mlflow.LogParam.Response\"m\xf2\x86\x19i\n(\n\x04POST\x12\x1a/mlflow/runs/log-parameter\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/runs/log-parameter\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog Param\x12\xe1\x01\n\x10setExperimentTag\x12\x18.mlflow.SetExperimentTag\x1a!.mlflow.SetExperimentTag.Response\"\x8f\x01\xf2\x86\x19\x8a\x01\n4\n\x04POST\x12&/mlflow/experiments/set-experiment-tag\x1a\x04\x08\x02\x10\x00\n<\n\x04POST\x12./preview/mlflow/experiments/set-experiment-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Set Experiment Tag\x12\x92\x01\n\x06setTag\x12\x0e.mlflow.SetTag\x1a\x17.mlflow.SetTag.Response\"_\xf2\x86\x19[\n\"\n\x04POST\x12\x14/mlflow/runs/set-tag\x1a\x04\x08\x02\x10\x00\n*\n\x04POST\x12\x1c/preview/mlflow/runs/set-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\x07Set Tag\x12\xa4\x01\n\tdeleteTag\x12\x11.mlflow.DeleteTag\x1a\x1a.mlflow.DeleteTag.Response\"h\xf2\x86\x19\x64\n%\n\x04POST\x12\x17/mlflow/runs/delete-tag\x1a\x04\x08\x02\x10\x00\n-\n\x04POST\x12\x1f/preview/mlflow/runs/delete-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\nDelete Tag\x12\x88\x01\n\x06getRun\x12\x0e.mlflow.GetRun\x1a\x17.mlflow.GetRun.Response\"U\xf2\x86\x19Q\n\x1d\n\x03GET\x12\x10/mlflow/runs/get\x1a\x04\x08\x02\x10\x00\n%\n\x03GET\x12\x18/preview/mlflow/runs/get\x1a\x04\x08\x02\x10\x00\x10\x01*\x07Get Run\x12\xcc\x01\n\nsearchRuns\x12\x12.mlflow.SearchRuns\x1a\x1b.mlflow.SearchRuns.Response\"\x8c\x01\xf2\x86\x19\x87\x01\n!\n\x04POST\x12\x13/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\n(\n\x03GET\x12\x1b/preview/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\x10\x01*\x0bSearch 
Runs\x12\xb0\x01\n\rlistArtifacts\x12\x15.mlflow.ListArtifacts\x1a\x1e.mlflow.ListArtifacts.Response\"h\xf2\x86\x19\x64\n#\n\x03GET\x12\x16/mlflow/artifacts/list\x1a\x04\x08\x02\x10\x00\n+\n\x03GET\x12\x1e/preview/mlflow/artifacts/list\x1a\x04\x08\x02\x10\x00\x10\x01*\x0eList Artifacts\x12\xc7\x01\n\x10getMetricHistory\x12\x18.mlflow.GetMetricHistory\x1a!.mlflow.GetMetricHistory.Response\"v\xf2\x86\x19r\n(\n\x03GET\x12\x1b/mlflow/metrics/get-history\x1a\x04\x08\x02\x10\x00\n0\n\x03GET\x12#/preview/mlflow/metrics/get-history\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Get Metric History\x12\x9e\x01\n\x08logBatch\x12\x10.mlflow.LogBatch\x1a\x19.mlflow.LogBatch.Response\"e\xf2\x86\x19\x61\n$\n\x04POST\x12\x16/mlflow/runs/log-batch\x1a\x04\x08\x02\x10\x00\n,\n\x04POST\x12\x1e/preview/mlflow/runs/log-batch\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog Batch\x12\x9e\x01\n\x08logModel\x12\x10.mlflow.LogModel\x1a\x19.mlflow.LogModel.Response\"e\xf2\x86\x19\x61\n$\n\x04POST\x12\x16/mlflow/runs/log-model\x1a\x04\x08\x02\x10\x00\n,\n\x04POST\x12\x1e/preview/mlflow/runs/log-model\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog ModelB\x1e\n\x14org.mlflow.api.proto\x90\x01\x01\xe2?\x02\x10\x01') + serialized_pb=_b('\n\rservice.proto\x12\x06mlflow\x1a\x15scalapb/scalapb.proto\x1a\x10\x64\x61tabricks.proto\"H\n\x06Metric\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x01\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x04step\x18\x04 \x01(\x03:\x01\x30\"#\n\x05Param\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"C\n\x03Run\x12\x1d\n\x04info\x18\x01 \x01(\x0b\x32\x0f.mlflow.RunInfo\x12\x1d\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x0f.mlflow.RunData\"g\n\x07RunData\x12\x1f\n\x07metrics\x18\x01 \x03(\x0b\x32\x0e.mlflow.Metric\x12\x1d\n\x06params\x18\x02 \x03(\x0b\x32\r.mlflow.Param\x12\x1c\n\x04tags\x18\x03 \x03(\x0b\x32\x0e.mlflow.RunTag\"$\n\x06RunTag\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"+\n\rExperimentTag\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\xcb\x01\n\x07RunInfo\x12\x0e\n\x06run_id\x18\x0f \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x15\n\rexperiment_id\x18\x02 \x01(\t\x12\x0f\n\x07user_id\x18\x06 \x01(\t\x12!\n\x06status\x18\x07 \x01(\x0e\x32\x11.mlflow.RunStatus\x12\x12\n\nstart_time\x18\x08 \x01(\x03\x12\x10\n\x08\x65nd_time\x18\t \x01(\x03\x12\x14\n\x0c\x61rtifact_uri\x18\r \x01(\t\x12\x17\n\x0flifecycle_stage\x18\x0e \x01(\t\"\xbb\x01\n\nExperiment\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x19\n\x11\x61rtifact_location\x18\x03 \x01(\t\x12\x17\n\x0flifecycle_stage\x18\x04 \x01(\t\x12\x18\n\x10last_update_time\x18\x05 \x01(\x03\x12\x15\n\rcreation_time\x18\x06 \x01(\x03\x12#\n\x04tags\x18\x07 \x03(\x0b\x32\x15.mlflow.ExperimentTag\"\x91\x01\n\x10\x43reateExperiment\x12\x12\n\x04name\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x19\n\x11\x61rtifact_location\x18\x02 \x01(\t\x1a!\n\x08Response\x12\x15\n\rexperiment_id\x18\x01 \x01(\t:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x98\x01\n\x0fListExperiments\x12#\n\tview_type\x18\x01 \x01(\x0e\x32\x10.mlflow.ViewType\x1a\x33\n\x08Response\x12\'\n\x0b\x65xperiments\x18\x01 \x03(\x0b\x32\x12.mlflow.Experiment:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb0\x01\n\rGetExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1aU\n\x08Response\x12&\n\nexperiment\x18\x01 \x01(\x0b\x32\x12.mlflow.Experiment\x12!\n\x04runs\x18\x02 \x03(\x0b\x32\x0f.mlflow.RunInfoB\x02\x18\x01:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"h\n\x10\x44\x65leteExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"i\n\x11RestoreExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"z\n\x10UpdateExperiment\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x10\n\x08new_name\x18\x02 
\x01(\t\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb8\x01\n\tCreateRun\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\t\x12\x12\n\nstart_time\x18\x07 \x01(\x03\x12\x1c\n\x04tags\x18\t \x03(\x0b\x32\x0e.mlflow.RunTag\x1a$\n\x08Response\x12\x18\n\x03run\x18\x01 \x01(\x0b\x32\x0b.mlflow.Run:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xbe\x01\n\tUpdateRun\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0e\x32\x11.mlflow.RunStatus\x12\x10\n\x08\x65nd_time\x18\x03 \x01(\x03\x1a-\n\x08Response\x12!\n\x08run_info\x18\x01 \x01(\x0b\x32\x0f.mlflow.RunInfo:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"Z\n\tDeleteRun\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"[\n\nRestoreRun\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb8\x01\n\tLogMetric\x12\x0e\n\x06run_id\x18\x06 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\x01\x42\x04\xf8\x86\x19\x01\x12\x17\n\ttimestamp\x18\x04 \x01(\x03\x42\x04\xf8\x86\x19\x01\x12\x0f\n\x04step\x18\x05 \x01(\x03:\x01\x30\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x8d\x01\n\x08LogParam\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x90\x01\n\x10SetExperimentTag\x12\x1b\n\rexperiment_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x8b\x01\n\x06SetTag\x12\x0e\n\x06run_id\x18\x04 
\x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x12\x13\n\x05value\x18\x03 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"m\n\tDeleteTag\x12\x14\n\x06run_id\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x12\x11\n\x03key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"}\n\x06GetRun\x12\x0e\n\x06run_id\x18\x02 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x1a$\n\x08Response\x12\x18\n\x03run\x18\x01 \x01(\x0b\x32\x0b.mlflow.Run:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x98\x02\n\nSearchRuns\x12\x16\n\x0e\x65xperiment_ids\x18\x01 \x03(\t\x12\x0e\n\x06\x66ilter\x18\x04 \x01(\t\x12\x34\n\rrun_view_type\x18\x03 \x01(\x0e\x32\x10.mlflow.ViewType:\x0b\x41\x43TIVE_ONLY\x12\x19\n\x0bmax_results\x18\x05 \x01(\x05:\x04\x31\x30\x30\x30\x12\x10\n\x08order_by\x18\x06 \x03(\t\x12\x12\n\npage_token\x18\x07 \x01(\t\x1a>\n\x08Response\x12\x19\n\x04runs\x18\x01 \x03(\x0b\x32\x0b.mlflow.Run\x12\x17\n\x0fnext_page_token\x18\x02 \x01(\t:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xd8\x01\n\rListArtifacts\x12\x0e\n\x06run_id\x18\x03 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x0c\n\x04path\x18\x02 \x01(\t\x12\x12\n\npage_token\x18\x04 \x01(\t\x1aV\n\x08Response\x12\x10\n\x08root_uri\x18\x01 \x01(\t\x12\x1f\n\x05\x66iles\x18\x02 \x03(\x0b\x32\x10.mlflow.FileInfo\x12\x17\n\x0fnext_page_token\x18\x03 \x01(\t:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\";\n\x08\x46ileInfo\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x0e\n\x06is_dir\x18\x02 \x01(\x08\x12\x11\n\tfile_size\x18\x03 \x01(\x03\"\xa8\x01\n\x10GetMetricHistory\x12\x0e\n\x06run_id\x18\x03 \x01(\t\x12\x10\n\x08run_uuid\x18\x01 \x01(\t\x12\x18\n\nmetric_key\x18\x02 \x01(\tB\x04\xf8\x86\x19\x01\x1a+\n\x08Response\x12\x1f\n\x07metrics\x18\x01 
\x03(\x0b\x32\x0e.mlflow.Metric:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\xb1\x01\n\x08LogBatch\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x1f\n\x07metrics\x18\x02 \x03(\x0b\x32\x0e.mlflow.Metric\x12\x1d\n\x06params\x18\x03 \x03(\x0b\x32\r.mlflow.Param\x12\x1c\n\x04tags\x18\x04 \x03(\x0b\x32\x0e.mlflow.RunTag\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"g\n\x08LogModel\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x12\n\nmodel_json\x18\x02 \x01(\t\x1a\n\n\x08Response:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]\"\x95\x01\n\x13GetExperimentByName\x12\x1d\n\x0f\x65xperiment_name\x18\x01 \x01(\tB\x04\xf8\x86\x19\x01\x1a\x32\n\x08Response\x12&\n\nexperiment\x18\x01 \x01(\x0b\x32\x12.mlflow.Experiment:+\xe2?(\n&com.databricks.rpc.RPC[$this.Response]*6\n\x08ViewType\x12\x0f\n\x0b\x41\x43TIVE_ONLY\x10\x01\x12\x10\n\x0c\x44\x45LETED_ONLY\x10\x02\x12\x07\n\x03\x41LL\x10\x03*I\n\nSourceType\x12\x0c\n\x08NOTEBOOK\x10\x01\x12\x07\n\x03JOB\x10\x02\x12\x0b\n\x07PROJECT\x10\x03\x12\t\n\x05LOCAL\x10\x04\x12\x0c\n\x07UNKNOWN\x10\xe8\x07*M\n\tRunStatus\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0c\n\x08\x46INISHED\x10\x03\x12\n\n\x06\x46\x41ILED\x10\x04\x12\n\n\x06KILLED\x10\x05\x32\xe1\x1e\n\rMlflowService\x12\xa6\x01\n\x13getExperimentByName\x12\x1b.mlflow.GetExperimentByName\x1a$.mlflow.GetExperimentByName.Response\"L\xf2\x86\x19H\n,\n\x03GET\x12\x1f/mlflow/experiments/get-by-name\x1a\x04\x08\x02\x10\x00\x10\x01*\x16Get Experiment By Name\x12\xc6\x01\n\x10\x63reateExperiment\x12\x18.mlflow.CreateExperiment\x1a!.mlflow.CreateExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/create\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/create\x1a\x04\x08\x02\x10\x00\x10\x01*\x11\x43reate Experiment\x12\xbc\x01\n\x0flistExperiments\x12\x17.mlflow.ListExperiments\x1a .mlflow.ListExperiments.Response\"n\xf2\x86\x19j\n%\n\x03GET\x12\x18/mlflow/experiments/list\x1a\x04\x08\x02\x10\x00\n-\n\x03GET\x12 
/preview/mlflow/experiments/list\x1a\x04\x08\x02\x10\x00\x10\x01*\x10List Experiments\x12\xb2\x01\n\rgetExperiment\x12\x15.mlflow.GetExperiment\x1a\x1e.mlflow.GetExperiment.Response\"j\xf2\x86\x19\x66\n$\n\x03GET\x12\x17/mlflow/experiments/get\x1a\x04\x08\x02\x10\x00\n,\n\x03GET\x12\x1f/preview/mlflow/experiments/get\x1a\x04\x08\x02\x10\x00\x10\x01*\x0eGet Experiment\x12\xc6\x01\n\x10\x64\x65leteExperiment\x12\x18.mlflow.DeleteExperiment\x1a!.mlflow.DeleteExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/delete\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/delete\x1a\x04\x08\x02\x10\x00\x10\x01*\x11\x44\x65lete Experiment\x12\xcc\x01\n\x11restoreExperiment\x12\x19.mlflow.RestoreExperiment\x1a\".mlflow.RestoreExperiment.Response\"x\xf2\x86\x19t\n)\n\x04POST\x12\x1b/mlflow/experiments/restore\x1a\x04\x08\x02\x10\x00\n1\n\x04POST\x12#/preview/mlflow/experiments/restore\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Restore Experiment\x12\xc6\x01\n\x10updateExperiment\x12\x18.mlflow.UpdateExperiment\x1a!.mlflow.UpdateExperiment.Response\"u\xf2\x86\x19q\n(\n\x04POST\x12\x1a/mlflow/experiments/update\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/experiments/update\x1a\x04\x08\x02\x10\x00\x10\x01*\x11Update Experiment\x12\x9c\x01\n\tcreateRun\x12\x11.mlflow.CreateRun\x1a\x1a.mlflow.CreateRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/create\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/create\x1a\x04\x08\x02\x10\x00\x10\x01*\nCreate Run\x12\x9c\x01\n\tupdateRun\x12\x11.mlflow.UpdateRun\x1a\x1a.mlflow.UpdateRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/update\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/update\x1a\x04\x08\x02\x10\x00\x10\x01*\nUpdate 
Run\x12\x9c\x01\n\tdeleteRun\x12\x11.mlflow.DeleteRun\x1a\x1a.mlflow.DeleteRun.Response\"`\xf2\x86\x19\\\n!\n\x04POST\x12\x13/mlflow/runs/delete\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/delete\x1a\x04\x08\x02\x10\x00\x10\x01*\nDelete Run\x12\xa2\x01\n\nrestoreRun\x12\x12.mlflow.RestoreRun\x1a\x1b.mlflow.RestoreRun.Response\"c\xf2\x86\x19_\n\"\n\x04POST\x12\x14/mlflow/runs/restore\x1a\x04\x08\x02\x10\x00\n*\n\x04POST\x12\x1c/preview/mlflow/runs/restore\x1a\x04\x08\x02\x10\x00\x10\x01*\x0bRestore Run\x12\xa4\x01\n\tlogMetric\x12\x11.mlflow.LogMetric\x1a\x1a.mlflow.LogMetric.Response\"h\xf2\x86\x19\x64\n%\n\x04POST\x12\x17/mlflow/runs/log-metric\x1a\x04\x08\x02\x10\x00\n-\n\x04POST\x12\x1f/preview/mlflow/runs/log-metric\x1a\x04\x08\x02\x10\x00\x10\x01*\nLog Metric\x12\xa6\x01\n\x08logParam\x12\x10.mlflow.LogParam\x1a\x19.mlflow.LogParam.Response\"m\xf2\x86\x19i\n(\n\x04POST\x12\x1a/mlflow/runs/log-parameter\x1a\x04\x08\x02\x10\x00\n0\n\x04POST\x12\"/preview/mlflow/runs/log-parameter\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog Param\x12\xe1\x01\n\x10setExperimentTag\x12\x18.mlflow.SetExperimentTag\x1a!.mlflow.SetExperimentTag.Response\"\x8f\x01\xf2\x86\x19\x8a\x01\n4\n\x04POST\x12&/mlflow/experiments/set-experiment-tag\x1a\x04\x08\x02\x10\x00\n<\n\x04POST\x12./preview/mlflow/experiments/set-experiment-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Set Experiment Tag\x12\x92\x01\n\x06setTag\x12\x0e.mlflow.SetTag\x1a\x17.mlflow.SetTag.Response\"_\xf2\x86\x19[\n\"\n\x04POST\x12\x14/mlflow/runs/set-tag\x1a\x04\x08\x02\x10\x00\n*\n\x04POST\x12\x1c/preview/mlflow/runs/set-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\x07Set Tag\x12\xa4\x01\n\tdeleteTag\x12\x11.mlflow.DeleteTag\x1a\x1a.mlflow.DeleteTag.Response\"h\xf2\x86\x19\x64\n%\n\x04POST\x12\x17/mlflow/runs/delete-tag\x1a\x04\x08\x02\x10\x00\n-\n\x04POST\x12\x1f/preview/mlflow/runs/delete-tag\x1a\x04\x08\x02\x10\x00\x10\x01*\nDelete 
Tag\x12\x88\x01\n\x06getRun\x12\x0e.mlflow.GetRun\x1a\x17.mlflow.GetRun.Response\"U\xf2\x86\x19Q\n\x1d\n\x03GET\x12\x10/mlflow/runs/get\x1a\x04\x08\x02\x10\x00\n%\n\x03GET\x12\x18/preview/mlflow/runs/get\x1a\x04\x08\x02\x10\x00\x10\x01*\x07Get Run\x12\xcc\x01\n\nsearchRuns\x12\x12.mlflow.SearchRuns\x1a\x1b.mlflow.SearchRuns.Response\"\x8c\x01\xf2\x86\x19\x87\x01\n!\n\x04POST\x12\x13/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\n)\n\x04POST\x12\x1b/preview/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\n(\n\x03GET\x12\x1b/preview/mlflow/runs/search\x1a\x04\x08\x02\x10\x00\x10\x01*\x0bSearch Runs\x12\xb0\x01\n\rlistArtifacts\x12\x15.mlflow.ListArtifacts\x1a\x1e.mlflow.ListArtifacts.Response\"h\xf2\x86\x19\x64\n#\n\x03GET\x12\x16/mlflow/artifacts/list\x1a\x04\x08\x02\x10\x00\n+\n\x03GET\x12\x1e/preview/mlflow/artifacts/list\x1a\x04\x08\x02\x10\x00\x10\x01*\x0eList Artifacts\x12\xc7\x01\n\x10getMetricHistory\x12\x18.mlflow.GetMetricHistory\x1a!.mlflow.GetMetricHistory.Response\"v\xf2\x86\x19r\n(\n\x03GET\x12\x1b/mlflow/metrics/get-history\x1a\x04\x08\x02\x10\x00\n0\n\x03GET\x12#/preview/mlflow/metrics/get-history\x1a\x04\x08\x02\x10\x00\x10\x01*\x12Get Metric History\x12\x9e\x01\n\x08logBatch\x12\x10.mlflow.LogBatch\x1a\x19.mlflow.LogBatch.Response\"e\xf2\x86\x19\x61\n$\n\x04POST\x12\x16/mlflow/runs/log-batch\x1a\x04\x08\x02\x10\x00\n,\n\x04POST\x12\x1e/preview/mlflow/runs/log-batch\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog Batch\x12\x9e\x01\n\x08logModel\x12\x10.mlflow.LogModel\x1a\x19.mlflow.LogModel.Response\"e\xf2\x86\x19\x61\n$\n\x04POST\x12\x16/mlflow/runs/log-model\x1a\x04\x08\x02\x10\x00\n,\n\x04POST\x12\x1e/preview/mlflow/runs/log-model\x1a\x04\x08\x02\x10\x00\x10\x01*\tLog ModelB\x1e\n\x14org.mlflow.api.proto\x90\x01\x01\xe2?\x02\x10\x01') , dependencies=[scalapb_dot_scalapb__pb2.DESCRIPTOR,databricks__pb2.DESCRIPTOR,]) @@ -49,8 +49,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=4198, - serialized_end=4252, + serialized_start=4243, + 
serialized_end=4297, ) _sym_db.RegisterEnumDescriptor(_VIEWTYPE) @@ -84,8 +84,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=4254, - serialized_end=4327, + serialized_start=4299, + serialized_end=4372, ) _sym_db.RegisterEnumDescriptor(_SOURCETYPE) @@ -119,8 +119,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=4329, - serialized_end=4406, + serialized_start=4374, + serialized_end=4451, ) _sym_db.RegisterEnumDescriptor(_RUNSTATUS) @@ -1748,6 +1748,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='next_page_token', full_name='mlflow.ListArtifacts.Response.next_page_token', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -1760,8 +1767,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3421, - serialized_end=3482, + serialized_start=3441, + serialized_end=3527, ) _LISTARTIFACTS = _descriptor.Descriptor( @@ -1792,6 +1799,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='page_token', full_name='mlflow.ListArtifacts.page_token', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -1805,7 +1819,7 @@ oneofs=[ ], serialized_start=3356, - serialized_end=3527, + serialized_end=3572, ) @@ -1849,8 +1863,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3529, - serialized_end=3588, + 
serialized_start=3574, + serialized_end=3633, ) @@ -1880,8 +1894,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3671, - serialized_end=3714, + serialized_start=3716, + serialized_end=3759, ) _GETMETRICHISTORY = _descriptor.Descriptor( @@ -1924,8 +1938,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3591, - serialized_end=3759, + serialized_start=3636, + serialized_end=3804, ) @@ -1999,8 +2013,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3762, - serialized_end=3939, + serialized_start=3807, + serialized_end=3984, ) @@ -2060,8 +2074,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=3941, - serialized_end=4044, + serialized_start=3986, + serialized_end=4089, ) @@ -2121,8 +2135,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=4047, - serialized_end=4196, + serialized_start=4092, + serialized_end=4241, ) _RUN.fields_by_name['info'].message_type = _RUNINFO @@ -2653,8 +2667,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=4409, - serialized_end=8346, + serialized_start=4454, + serialized_end=8391, methods=[ _descriptor.MethodDescriptor( name='getExperimentByName', diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 405c2547ed769..63637801100b4 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -237,20 +237,33 @@ def list_artifacts(self, path=None): self.run_relative_artifact_repo_root_path, path) else: run_relative_path = self.run_relative_artifact_repo_root_path - json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path)) - artifact_list = self._call_endpoint(MlflowService, ListArtifacts, json_body).files - # If `path` is a file, ListArtifacts returns a single list element with the - # same name as `path`. The list_artifacts API expects us to return an empty list in this - # case, so we do so here. 
- if len(artifact_list) == 1 and artifact_list[0].path == path \ - and not artifact_list[0].is_dir: - return [] - infos = list() - for output_file in artifact_list: - file_rel_path = posixpath.relpath( - path=output_file.path, start=self.run_relative_artifact_repo_root_path) - artifact_size = None if output_file.is_dir else output_file.file_size - infos.append(FileInfo(file_rel_path, output_file.is_dir, artifact_size)) + + infos = [] + page_token = None + while True: + if page_token: + json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path, page_token=page_token)) + else: + json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path)) + json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path)) + response = self._call_endpoint(MlflowService, ListArtifacts, json_body) + artifact_list = response.files + # If `path` is a file, ListArtifacts returns a single list element with the + # same name as `path`. The list_artifacts API expects us to return an empty list in this + # case, so we do so here. 
+ if len(artifact_list) == 1 and artifact_list[0].path == path \ + and not artifact_list[0].is_dir: + return [] + for output_file in artifact_list: + file_rel_path = posixpath.relpath( + path=output_file.path, start=self.run_relative_artifact_repo_root_path) + artifact_size = None if output_file.is_dir else output_file.file_size + infos.append(FileInfo(file_rel_path, output_file.is_dir, artifact_size)) + + if len(artifact_list) == 0 or not response.next_page_token: + break + page_token = response.next_page_token + return infos def _download_file(self, remote_file_path, local_path): From 30e3885c686a96aedf949efa6638ac49fd534f69 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Tue, 9 Jun 2020 14:40:19 -0700 Subject: [PATCH 20/28] Fix --- mlflow/store/artifact/databricks_artifact_repo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 63637801100b4..f03ac87166346 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -241,11 +241,12 @@ def list_artifacts(self, path=None): infos = [] page_token = None while True: + print("PAGE TOKEN: {}".format(page_token)) + print("ARTIFACT LIST: {}".format(infos)) if page_token: json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path, page_token=page_token)) else: json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path)) - json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path)) response = self._call_endpoint(MlflowService, ListArtifacts, json_body) artifact_list = response.files # If `path` is a file, ListArtifacts returns a single list element with the From a03e099098288188e91169bded6e0a05ef329022 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Tue, 9 Jun 2020 22:46:15 -0700 Subject: [PATCH 21/28] Added relative path test cases --- .../artifact/databricks_artifact_repo.py | 1 + 
.../artifact/test_databricks_artifact_repo.py | 84 +++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 405c2547ed769..c4d7056b41f7b 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -237,6 +237,7 @@ def list_artifacts(self, path=None): self.run_relative_artifact_repo_root_path, path) else: run_relative_path = self.run_relative_artifact_repo_root_path + print (ListArtifacts(run_id=self.run_id, path=run_relative_path)) json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path)) artifact_list = self._call_endpoint(MlflowService, ListArtifacts, json_body).files # If `path` is a file, ListArtifacts returns a single list element with the diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 2e1af3504cbb9..a21ee06863afa 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -26,6 +26,8 @@ ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')] MOCK_RUN_ROOT_URI = \ "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" +MOCK_SUBDIR = "subdir/path" +MOCK_SUBDIR_ROOT_URI = os.path.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) @pytest.fixture() @@ -235,6 +237,30 @@ def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, te databricks_artifact_repo.log_artifact(test_file.strpath) write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) + @pytest.mark.parametrize("artifact_path,expected_location", [ + (None, os.path.join(MOCK_SUBDIR, "test.txt")), + ('test_path', os.path.join(MOCK_SUBDIR, "test_path/test.txt")), + ]) + def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + 
'._get_run_artifact_root') \ + as get_run_artifact_root_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \ + as write_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._upload_to_cloud') \ + as upload_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI) + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + write_credentials_response_proto = GetCredentialsForWrite.Response( + credentials=mock_credentials) + write_credentials_mock.return_value = write_credentials_response_proto + upload_mock.return_value = None + databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path) + write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location) + upload_mock.assert_called_with(write_credentials_response_proto, test_file.strpath, + expected_location) + @pytest.mark.parametrize("artifact_path", [ None, 'output/', @@ -280,6 +306,36 @@ def test_list_artifacts(self, databricks_artifact_repo): artifacts = databricks_artifact_repo.list_artifacts('a.txt') assert len(artifacts) == 0 + def test_list_artifacts_with_relative_path(self): + list_artifacts_dir_proto_mock = [ + FileInfo(path=os.path.join(MOCK_SUBDIR, 'test/a.txt'), is_dir=False, file_size=100), + FileInfo(path=os.path.join(MOCK_SUBDIR, 'test/dir'), is_dir=True, file_size=0) + ] + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock, \ + mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json')as message_mock, \ + mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + list_artifact_response_proto = \ + ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) + call_endpoint_mock.return_value = 
list_artifact_response_proto + message_mock.return_value = None + databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI) + artifacts = databricks_artifact_repo.list_artifacts('test') + assert isinstance(artifacts, list) + assert isinstance(artifacts[0], FileInfoEntity) + assert len(artifacts) == 2 + assert artifacts[0].path == 'test/a.txt' + assert artifacts[0].is_dir is False + assert artifacts[0].file_size == 100 + assert artifacts[1].path == 'test/dir' + assert artifacts[1].is_dir is True + assert artifacts[1].file_size is None + message_mock.assert_called_with( + ListArtifacts(run_id=MOCK_RUN_ID, path=os.path.join(MOCK_SUBDIR, "test"))) + @pytest.mark.parametrize( "remote_file_path, local_path, cloud_credential_type", [ ('test_file.txt', '', ArtifactCredentialType.AZURE_SAS_URI), @@ -306,6 +362,34 @@ def test_databricks_download_file(self, databricks_artifact_repo, remote_file_pa read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path) download_mock.assert_called_with(mock_credentials, ANY) + @pytest.mark.parametrize( + "remote_file_path, local_path", [ + ('test_file.txt', ''), + ('test_file.txt', None), + ]) + def test_databricks_download_file_with_relative_path(self, remote_file_path, local_path): + with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ + as get_run_artifact_root_mock, \ + mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ + as read_credentials_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ + mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ + as download_mock: + get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI + mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, + type=ArtifactCredentialType.AZURE_SAS_URI) + read_credentials_response_proto = GetCredentialsForRead.Response( + credentials=mock_credentials) + read_credentials_mock.return_value = 
read_credentials_response_proto + download_mock.return_value = None + get_list_mock.return_value = [] + databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI) + databricks_artifact_repo.download_artifacts(remote_file_path, local_path) + read_credentials_mock.assert_called_with(MOCK_RUN_ID, + os.path.join(MOCK_SUBDIR, remote_file_path)) + download_mock.assert_called_with(mock_credentials, ANY) + def test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file): with mock.patch( DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ From 6d3629701459f86fd786d85db46124f64c5d515a Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Wed, 10 Jun 2020 12:16:01 -0700 Subject: [PATCH 22/28] Added test for list_artifacts pagination --- .../artifact/databricks_artifact_repo.py | 10 ++-- .../artifact/test_databricks_artifact_repo.py | 53 +++++++++++++++++-- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 1ee41dd48ae74..06198d5da617f 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -240,12 +240,13 @@ def list_artifacts(self, path=None): infos = [] page_token = None while True: - print("PAGE TOKEN: {}".format(page_token)) - print("ARTIFACT LIST: {}".format(infos)) if page_token: - json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path, page_token=page_token)) + json_body = message_to_json( + ListArtifacts(run_id=self.run_id, path=run_relative_path, + page_token=page_token)) else: - json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path)) + json_body = message_to_json( + ListArtifacts(run_id=self.run_id, path=run_relative_path)) response = self._call_endpoint(MlflowService, ListArtifacts, json_body) artifact_list = response.files # If `path` is a file, ListArtifacts returns a single list element with the @@ 
-259,7 +260,6 @@ def list_artifacts(self, path=None): path=output_file.path, start=self.run_relative_artifact_repo_root_path) artifact_size = None if output_file.is_dir else output_file.file_size infos.append(FileInfo(file_rel_path, output_file.is_dir, artifact_size)) - if len(artifact_list) == 0 or not response.next_page_token: break page_token = response.next_page_token diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index a21ee06863afa..ecea79bea8627 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -285,7 +285,8 @@ def test_list_artifacts(self, databricks_artifact_repo): FileInfo(path='test/dir', is_dir=True, file_size=0)] with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: list_artifact_response_proto = \ - ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) + ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock, + next_page_token=None) call_endpoint_mock.return_value = list_artifact_response_proto artifacts = databricks_artifact_repo.list_artifacts('test/') assert isinstance(artifacts, list) @@ -319,7 +320,8 @@ def test_list_artifacts_with_relative_path(self): DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI list_artifact_response_proto = \ - ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock) + ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock, + next_page_token=None) call_endpoint_mock.return_value = list_artifact_response_proto message_mock.return_value = None databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI) @@ -336,6 +338,47 @@ def test_list_artifacts_with_relative_path(self): message_mock.assert_called_with( ListArtifacts(run_id=MOCK_RUN_ID, path=os.path.join(MOCK_SUBDIR, 
"test"))) + def test_paginated_list_artifacts(self, databricks_artifact_repo): + list_artifacts_proto_mock_1 = [ + FileInfo(path='a.txt', is_dir=False, file_size=100), + FileInfo(path='b', is_dir=True, file_size=0) + ] + list_artifacts_proto_mock_2 = [ + FileInfo(path='c.txt', is_dir=False, file_size=100), + FileInfo(path='d', is_dir=True, file_size=0) + ] + list_artifacts_proto_mock_3 = [ + FileInfo(path='e.txt', is_dir=False, file_size=100), + FileInfo(path='f', is_dir=True, file_size=0) + ] + list_artifacts_proto_mock_4 = [] + with mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json')as message_mock, \ + mock.patch( + DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock: + list_artifact_paginated_response_protos = [ + ListArtifacts.Response(root_uri='', files=list_artifacts_proto_mock_1, + next_page_token='2'), + ListArtifacts.Response(root_uri='', files=list_artifacts_proto_mock_2, + next_page_token='4'), + ListArtifacts.Response(root_uri='', files=list_artifacts_proto_mock_3, + next_page_token='6'), + ListArtifacts.Response(root_uri='', files=list_artifacts_proto_mock_4, + next_page_token='8'), + ] + call_endpoint_mock.side_effect = list_artifact_paginated_response_protos + message_mock.return_value = None + artifacts = databricks_artifact_repo.list_artifacts() + assert set(['a.txt', 'b', 'c.txt', 'd', 'e.txt', 'f']) == set( + [file.path for file in artifacts]) + calls = [ + mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="")), + mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token='2')), + mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token='4')), + mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token='6')) + ] + message_mock.assert_has_calls(calls) + @pytest.mark.parametrize( "remote_file_path, local_path, cloud_credential_type", [ ('test_file.txt', '', ArtifactCredentialType.AZURE_SAS_URI), @@ -350,7 +393,7 @@ def test_databricks_download_file(self, databricks_artifact_repo, 
remote_file_pa as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ - as download_mock: + as download_mock: mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type) read_credentials_response_proto = GetCredentialsForRead.Response( @@ -372,10 +415,10 @@ def test_databricks_download_file_with_relative_path(self, remote_file_path, loc as get_run_artifact_root_mock, \ mock.patch( DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \ - as read_credentials_mock, \ + as read_credentials_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \ mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \ - as download_mock: + as download_mock: get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI mock_credentials = ArtifactCredentialInfo(signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI) From ec85f0601e94902829f5673212579dad1c7f44d7 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 18:13:10 -0700 Subject: [PATCH 23/28] Fixing travis failures --- .../artifact/test_databricks_artifact_repo.py | 11 ++++++----- .../artifact/test_dbfs_artifact_repo_delegation.py | 14 ++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index ecea79bea8627..22352a9e97794 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -4,6 +4,7 @@ from azure.storage.blob import BlobClient import mock import pytest +import posixpath from requests.models import Response from unittest.mock import ANY @@ -27,7 +28,7 @@ MOCK_RUN_ROOT_URI = \ "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" MOCK_SUBDIR = "subdir/path" 
-MOCK_SUBDIR_ROOT_URI = os.path.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) +MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) @pytest.fixture() @@ -239,7 +240,7 @@ def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, te @pytest.mark.parametrize("artifact_path,expected_location", [ (None, os.path.join(MOCK_SUBDIR, "test.txt")), - ('test_path', os.path.join(MOCK_SUBDIR, "test_path/test.txt")), + ('test_path', posixpath.join(MOCK_SUBDIR, "test_path/test.txt")), ]) def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location): with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ @@ -271,13 +272,13 @@ def test_log_artifacts(self, databricks_artifact_repo, test_dir, artifact_path): log_artifact_mock.return_value = None databricks_artifact_repo.log_artifacts(test_dir.strpath, artifact_path) artifact_path = artifact_path or '' - calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), + expected_calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), os.path.join(artifact_path, '')), mock.call(os.path.join(test_dir.strpath, 'test.txt'), os.path.join(artifact_path, '')), mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), os.path.join(artifact_path, 'subdir'))] - log_artifact_mock.assert_has_calls(calls) + assert log_artifact_mock.mock_calls == expected_calls def test_list_artifacts(self, databricks_artifact_repo): list_artifact_file_proto_mock = [FileInfo(path='a.txt', is_dir=False, file_size=0)] @@ -430,7 +431,7 @@ def test_databricks_download_file_with_relative_path(self, remote_file_path, loc databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI) databricks_artifact_repo.download_artifacts(remote_file_path, local_path) read_credentials_mock.assert_called_with(MOCK_RUN_ID, - os.path.join(MOCK_SUBDIR, remote_file_path)) + posixpath.join(MOCK_SUBDIR, remote_file_path)) download_mock.assert_called_with(mock_credentials, ANY) def 
test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file): diff --git a/tests/store/artifact/test_dbfs_artifact_repo_delegation.py b/tests/store/artifact/test_dbfs_artifact_repo_delegation.py index 419fd9469b396..8c999a5a0621c 100644 --- a/tests/store/artifact/test_dbfs_artifact_repo_delegation.py +++ b/tests/store/artifact/test_dbfs_artifact_repo_delegation.py @@ -41,7 +41,13 @@ def test_dbfs_artifact_repo_delegates_to_correct_repo( assert isinstance(rest_repo, DbfsRestArtifactRepository) assert rest_repo.artifact_uri == artifact_uri - artifact_uri = "dbfs:/databricks/mlflow-tracking/my/absolute/dbfs/path" - databricks_repo = get_artifact_repository(artifact_uri) - assert isinstance(databricks_repo, DatabricksArtifactRepository) - assert databricks_repo.artifact_uri == artifact_uri + with mock.patch( + "mlflow.store.artifact.databricks_artifact_repo" + + ".DatabricksArtifactRepository._get_run_artifact_root") \ + as get_run_artifact_root_mock: + mock_uri = \ + "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" + get_run_artifact_root_mock.return_value = mock_uri + databricks_repo = get_artifact_repository(mock_uri) + assert isinstance(databricks_repo, DatabricksArtifactRepository) + assert databricks_repo.artifact_uri == mock_uri From c6585d68dcfd12ac2896a69e8460db82d8d7168b Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 18:48:19 -0700 Subject: [PATCH 24/28] Fixes --- .../artifact/test_databricks_artifact_repo.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 22352a9e97794..4d26a24e9a6e8 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -28,7 +28,7 @@ MOCK_RUN_ROOT_URI = \ "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" MOCK_SUBDIR = "subdir/path" 
-MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) +MOCK_SUBDIR_ROOT_URI = os.path.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) @pytest.fixture() @@ -238,9 +238,9 @@ def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, te databricks_artifact_repo.log_artifact(test_file.strpath) write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) - @pytest.mark.parametrize("artifact_path,expected_location", [ + @pytest.mark.parametrize("artifact_path, expected_location", [ (None, os.path.join(MOCK_SUBDIR, "test.txt")), - ('test_path', posixpath.join(MOCK_SUBDIR, "test_path/test.txt")), + ('test_path', os.path.join(MOCK_SUBDIR, "test_path/test.txt")), ]) def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location): with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ @@ -273,11 +273,11 @@ def test_log_artifacts(self, databricks_artifact_repo, test_dir, artifact_path): databricks_artifact_repo.log_artifacts(test_dir.strpath, artifact_path) artifact_path = artifact_path or '' expected_calls = [mock.call(os.path.join(test_dir.strpath, 'empty-file'), - os.path.join(artifact_path, '')), - mock.call(os.path.join(test_dir.strpath, 'test.txt'), - os.path.join(artifact_path, '')), - mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), - os.path.join(artifact_path, 'subdir'))] + os.path.join(artifact_path, '')), + mock.call(os.path.join(test_dir.strpath, 'test.txt'), + os.path.join(artifact_path, '')), + mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), + os.path.join(artifact_path, 'subdir'))] assert log_artifact_mock.mock_calls == expected_calls def test_list_artifacts(self, databricks_artifact_repo): From 51ca61f95a7c1cbd688c5ce825f533a3a8de7af4 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 18:56:06 -0700 Subject: [PATCH 25/28] More fixes --- tests/store/artifact/test_databricks_artifact_repo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 4d26a24e9a6e8..8ad0b01a884d1 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -28,7 +28,7 @@ MOCK_RUN_ROOT_URI = \ "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" MOCK_SUBDIR = "subdir/path" -MOCK_SUBDIR_ROOT_URI = os.path.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) +MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) @pytest.fixture() From d1b9dfd10f60f59b56e2c348ed346ba93b32425a Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 19:34:04 -0700 Subject: [PATCH 26/28] More fixes --- tests/store/artifact/test_databricks_artifact_repo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 8ad0b01a884d1..d9aa7a8653b9e 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -239,8 +239,8 @@ def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, te write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY) @pytest.mark.parametrize("artifact_path, expected_location", [ - (None, os.path.join(MOCK_SUBDIR, "test.txt")), - ('test_path', os.path.join(MOCK_SUBDIR, "test_path/test.txt")), + (None, posixpath.join(MOCK_SUBDIR, "test.txt")), + ('test_path', posixpath.join(MOCK_SUBDIR, "test_path/test.txt")), ]) def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location): with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ @@ -276,7 +276,8 @@ def test_log_artifacts(self, databricks_artifact_repo, test_dir, artifact_path): os.path.join(artifact_path, '')), mock.call(os.path.join(test_dir.strpath, 'test.txt'), os.path.join(artifact_path, '')), - 
mock.call(os.path.join(test_dir.strpath, 'subdir/test.txt'), + mock.call(os.path.join(test_dir.strpath, + os.path.join('subdir', 'test.txt')), os.path.join(artifact_path, 'subdir'))] assert log_artifact_mock.mock_calls == expected_calls @@ -293,10 +294,10 @@ def test_list_artifacts(self, databricks_artifact_repo): assert isinstance(artifacts, list) assert isinstance(artifacts[0], FileInfoEntity) assert len(artifacts) == 2 - assert artifacts[0].path == 'test/a.txt' + assert artifacts[0].path == os.path.join('test', 'a.txt') assert artifacts[0].is_dir is False assert artifacts[0].file_size == 100 - assert artifacts[1].path == 'test/dir' + assert artifacts[1].path == os.path.join('test', 'dir') assert artifacts[1].is_dir is True assert artifacts[1].file_size is None From 72ecab57d6bbe2f51b9e997bb34ee621b4653222 Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 20:03:31 -0700 Subject: [PATCH 27/28] Clean-up --- tests/store/artifact/test_databricks_artifact_repo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index d9aa7a8653b9e..1a7486c7a44ff 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -294,10 +294,10 @@ def test_list_artifacts(self, databricks_artifact_repo): assert isinstance(artifacts, list) assert isinstance(artifacts[0], FileInfoEntity) assert len(artifacts) == 2 - assert artifacts[0].path == os.path.join('test', 'a.txt') + assert artifacts[0].path == 'test/a.txt' assert artifacts[0].is_dir is False assert artifacts[0].file_size == 100 - assert artifacts[1].path == os.path.join('test', 'dir') + assert artifacts[1].path == 'test/dir' assert artifacts[1].is_dir is True assert artifacts[1].file_size is None @@ -311,8 +311,8 @@ def test_list_artifacts(self, databricks_artifact_repo): def test_list_artifacts_with_relative_path(self): 
list_artifacts_dir_proto_mock = [ - FileInfo(path=os.path.join(MOCK_SUBDIR, 'test/a.txt'), is_dir=False, file_size=100), - FileInfo(path=os.path.join(MOCK_SUBDIR, 'test/dir'), is_dir=True, file_size=0) + FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/a.txt'), is_dir=False, file_size=100), + FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/dir'), is_dir=True, file_size=0) ] with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \ as get_run_artifact_root_mock, \ From dfc7f6039b20fcfa0282c13f7d30a6f07a1bc1ad Mon Sep 17 00:00:00 2001 From: arjundc-db Date: Mon, 15 Jun 2020 20:14:37 -0700 Subject: [PATCH 28/28] Clean-up --- tests/store/artifact/test_databricks_artifact_repo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index 1a7486c7a44ff..2fd877245a891 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -338,7 +338,7 @@ def test_list_artifacts_with_relative_path(self): assert artifacts[1].is_dir is True assert artifacts[1].file_size is None message_mock.assert_called_with( - ListArtifacts(run_id=MOCK_RUN_ID, path=os.path.join(MOCK_SUBDIR, "test"))) + ListArtifacts(run_id=MOCK_RUN_ID, path=posixpath.join(MOCK_SUBDIR, "test"))) def test_paginated_list_artifacts(self, databricks_artifact_repo): list_artifacts_proto_mock_1 = [