Databricks artifact repo #5
base: master
Conversation
@arjundc-db Awesome work! I left a few comments; I still need to get my hands dirty with this against a Databricks environment; I'll go ahead and do that shortly. Let me know if you have any questions about the comments!
        super(DatabricksArtifactRepository, self).__init__(artifact_uri)

    def _extract_run_id(self, artifact_uri):
        return artifact_uri.lstrip('/').split('/')[4]
@arjundc-db Awesome that we're extracting the run ID from the URI! A quick suggestion here - can we: 1. use the urlparse library to fetch the path component of the URI and 2. normalize the path before performing the splitting? E.g.,
import posixpath
from urllib.parse import urlparse
parsed_uri = urlparse(artifact_uri)
parsed_path = posixpath.normpath(parsed_uri.path)
return parsed_path.lstrip("/").split("/")[3]  # index 3 of the normalized path is the RUN_ID component
The reasoning for #1 here is that DBFS URIs support empty hostnames; e.g., dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID is valid (note the 3 leading slashes). I confirmed this by running %fs ls dbfs:///databricks in a Databricks notebook. The current solution yields the incorrect result when applied directly to a URI of this form:
In [1]: url = "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
In [2]: url.lstrip('/').split('/')[4]
Out[2]: 'mlflow-tracking' # Should be 'RUN_ID'
Similarly, I've confirmed that dbfs://databricks/ (two leading slashes) and dbfs:databricks/... (no leading slashes) are invalid use cases; Databricks rejects URIs of the first form because they contain non-empty hostnames, which are unsupported. Databricks rejects URIs of the second form because they contain relative paths. Databricks requires URIs to specify absolute DBFS paths. This means that we don't have to worry about relative paths, which may not contain a leading slash.
The reasoning for #2 here is that URI paths may contain redundant slashes. For example, dbfs:/databricks/mlflow-tracking////EXP_ID/RUN_ID is a valid URI that is semantically equivalent to dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID. Normalizing the path will remove redundant slashes. Without normalizing the path, we currently get an incorrect result:
In [1]: url = "dbfs:/databricks////mlflow-tracking/EXP_ID/RUN_ID/"
In [2]: url.lstrip('/').split('/')[4]
Out [2]: ''
For more information about the structure of URIs, https://en.wikipedia.org/wiki/Uniform_Resource_Identifier is an awesome reference. Let me know if you have any questions here.
These kinds of URI edge cases would be great to handle via a unit test. Let's make sure we plan to add a unit test for this case before merging this functionality!
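For instance, a sketch of such a test (assuming _extract_run_id ends up as a static helper on DatabricksArtifactRepository that takes the URI and returns the run ID; the import path is hypothetical):

import pytest
from mlflow.store.artifact.databricks_artifact_repo import DatabricksArtifactRepository

@pytest.mark.parametrize("artifact_uri", [
    "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
    "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",   # empty hostname
    "dbfs:/databricks/mlflow-tracking////EXP_ID/RUN_ID/artifacts",  # redundant slashes
])
def test_extract_run_id_handles_semantically_equivalent_uris(artifact_uri):
    assert DatabricksArtifactRepository._extract_run_id(artifact_uri) == "RUN_ID"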
Once we correct the behavior here, can we also add an inline comment that clearly documents the assumptions we're making about the URI structure? E.g., URIs are assumed to be semantically equivalent to dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts/...? This will help future readers of the code interpret the behavior.
Done, this was pretty cool to read about, thanks!
"path": path | ||
} | ||
|
||
def _get_azure_write_credentials(self, run_id, path=None): |
This method calls an MLflow Artifacts Service endpoint that is not directly bound to Azure; it may also return AWS credentials. The type of credential is specified by the type field of the ArtifactCredentialInfo response. Accordingly, I think we should call this _get_write_credentials.
Callers of this method should check the type field of the response and then perform the corresponding cloud operation (either AWS or Azure). The current task is to implement Azure uploads / downloads, so it's fine to leave the AWS case unimplemented (e.g., simply raise a "Not Implemented" exception in the code). It would still be great to structure the upload / download logic so that it accounts for the fact that credentials may be provided for different cloud services. Let me know if this makes sense!
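Roughly something like the following (a sketch only; the enum values and method names mirror ones that appear later in this review):

    def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path):
        # Dispatch on the credential type returned by the MLflow Artifacts Service.
        if cloud_credentials.credentials.type == ArtifactCredentialType.AZURE_SAS_URI:
            self._azure_upload_file(cloud_credentials.credentials, local_file, artifact_path)
        elif cloud_credentials.credentials.type == ArtifactCredentialType.AWS_PRESIGNED_URL:
            # AWS uploads are out of scope for this change.
            raise MlflowException('Not implemented yet')
        else:
            raise MlflowException('Cloud provider not supported.')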
Done.
        return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite,
                                   self._create_json_body(run_id, path))

    def _get_azure_read_credentials(self, run_id, path=None):
This method calls an MLflow Artifacts Service endpoint that is not directly bound to Azure; it may also return AWS credentials. The type of credential is specified by the type field of the ArtifactCredentialInfo response. Accordingly, I think we should call this _get_read_credentials.
Callers of this method should check the type field of the response and then perform the corresponding cloud operation (either AWS or Azure). The current task is to implement Azure uploads / downloads, so it's fine to leave the AWS case unimplemented (e.g., simply raise a "Not Implemented" exception in the code). It would still be great to structure the upload / download logic so that it accounts for the fact that credentials may be provided for different cloud services. Let me know if this makes sense!
Done.
mlflow/utils/uri.py (Outdated)
@@ -129,3 +129,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path):
    # joined path
    suffix_path = suffix_path.lstrip(posixpath.sep)
    return posixpath.join(prefix_path, suffix_path)


def is_artifact_acled_uri(artifact_uri):
In the same vein as the comment above about semantically equivalent URIs (e.g., dbfs:///databricks/mlflow-tracking////EXP_ID/RUN_ID/...), this check will fail. Can we do the following instead?:
- Parse the URI
- Extract the scheme and verify that it's dbfs
- Extract the path, normalize it via posixpath.normpath()
- Verify that the normalized path has the expected prefix
It would also be great to unit test this method against various semantically equivalent URIs.
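A sketch of those steps (the prefix mirrors the _ACLED_ARTIFACT_URI constant elsewhere in this PR; exact return behavior is illustrative):

import posixpath
from urllib.parse import urlparse

def is_artifact_acled_uri(artifact_uri):
    # Steps 1 and 2: parse the URI and verify that the scheme is dbfs
    parsed_uri = urlparse(artifact_uri)
    if parsed_uri.scheme != "dbfs":
        return False
    # Step 3: normalize the path to collapse redundant slashes
    normalized_path = posixpath.normpath(parsed_uri.path)
    # Step 4: verify that the normalized path has the expected ACL'd prefix
    return normalized_path.lstrip("/").startswith("databricks/mlflow-tracking/")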
Done, however, I do not think step 2 is necessary since https://livegrep.dev.databricks.com/view/mlflow/mlflow/mlflow/store/artifact/artifact_repository_registry.py#L63 will only call dbfs_artifact_repo_factory when the scheme is dbfs.
@arjundc-db I think we're relying on an assumption about exactly where this function is going to be called. Future callers of this function are not guaranteed to perform dbfs scheme validation beforehand. Accordingly, we have a few options:
1. Perform the scheme validation in this method
2. Do not perform the scheme validation in this method and clearly document (via the method docstring) that this method assumes it is being passed a dbfs:/ URI and will not validate the scheme
3. If we don't foresee this method being called elsewhere, we can define it within the dbfs_artifact_repo_factory() method as a subroutine. This way, we can safely assume that we're already working with a dbfs URI (as documented by the dbfs_artifact_repo_factory docstring).
I'd vote for either option 1 or option 3!
Yep, that makes sense. Using option 3.
mlflow/utils/uri.py (Outdated)
@@ -129,3 +129,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path):
    # joined path
    suffix_path = suffix_path.lstrip(posixpath.sep)
    return posixpath.join(prefix_path, suffix_path)


def is_artifact_acled_uri(artifact_uri):
Nit: can we rename this to is_databricks_acled_artifacts_uri? This helps to clarify that we're checking for a piece of Databricks-specific functionality.
Done.
mlflow/utils/uri.py (Outdated)
@@ -8,7 +8,7 @@
_INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \
                      "format specifications."

_ACLED_ARTIFACT_URI = "dbfs:/databricks/mlflow-tracking/"
Nit: Can we move this into is_artifact_acled_uri()? It doesn't seem to be required by more than one function in this module.
Done.
        return artifact_uri.lstrip('/').split('/')[4]

    def _call_endpoint(self, service, api, json_body):
        _METHOD_TO_INFO = extract_api_info_for_service(service, _PATH_PREFIX)
This extract_api_info_for_service method is a bit expensive (https://github.com/mlflow/mlflow/blob/c4e0cbdf5c92b3b81f7080fff4a974114605d2db/mlflow/utils/rest_utils.py#L115). Can we instead define a module-level _SERVICE_AND_METHOD_TO_INFO dictionary as:
_SERVICE_AND_METHOD_TO_INFO = {
    service: extract_api_info_for_service(service, _PATH_PREFIX)
    for service in [MlflowService, DatabricksMlflowArtifactsService]
}
We can then fetch the info as _SERVICE_AND_METHOD_TO_INFO[service][api], without performing proto extraction operations each time this method is called.
Done.
        return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto)

    def _create_json_body(self, run_id, path=None):
        path = path or '.'
Nit:
-    path = path or '.'
+    path = path or ""
"" is more idiomatic for representing empty than "."
Done.
Codecov Report
@@            Coverage Diff            @@
##           master       #5   +/-   ##
=======================================
  Coverage   85.04%   85.04%
=======================================
  Files          20       20
  Lines        1050     1050
=======================================
  Hits          893      893
  Misses        157      157
Continue to review full report at Codecov.
            service: extract_api_info_for_service(service, _PATH_PREFIX)
            for service in [MlflowService, DatabricksMlflowArtifactsService]
        }
        self.credential_type_to_cloud_service = {
To be deleted. I was considering a dictionary (Cloud_Provider -> upload/download_function) but it looked too messy.
@arjundc-db Awesome progress! I left a few more comments. Still need to take a look at the tests, but this is awesome!
        if os.path.getsize(local_file) < _AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE:
            signed_write_uri = credentials.signed_uri
            service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None)
            with open(local_file, "rb") as data:
                service.upload_blob(data, overwrite=True)
@arjundc-db Examining the documentation of BlobClient.upload_blob() (https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobclient?view=azure-python#upload-blob-data--blob-type--blobtype-blockblob---blockblob----length-none--metadata-none----kwargs-), it appears that this method performs a block blob upload with automatic chunking (e.g., multiple chunks may be used for a single file of size < 256MB). This Python client API is not equivalent to the single-part blob upload API (PutBlob): https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob.
Accordingly, I think it's probably best to unconditionally use the 100MB self-chunking approach regardless of aggregate file size. Let me know if you have any questions about this.
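For reference, a rough sketch of that unconditional self-chunking approach with the azure-storage-blob v12 client (the chunk size and the headers-on-commit detail follow later comments in this review; block ID generation and the helper name are illustrative):

import base64
import uuid
from azure.storage.blob import BlobBlock, BlobClient

_AZURE_MAX_BLOCK_CHUNK_SIZE = 100 * 1024 * 1024  # stage blocks of at most 100 MB each

def _azure_upload_file_sketch(signed_write_uri, local_file, headers):
    # Always stage blocks, regardless of total file size, so the behavior is uniform
    # and never depends on a single-request size limit.
    service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None)
    uploading_block_list = []
    with open(local_file, "rb") as f:
        while True:
            chunk = f.read(_AZURE_MAX_BLOCK_CHUNK_SIZE)
            if not chunk:
                break
            block_id = base64.b64encode(uuid.uuid4().hex.encode()).decode()
            service.stage_block(block_id, chunk, headers=headers)
            uploading_block_list.append(BlobBlock(block_id=block_id))
    # Pass the same headers on the commit, since they may carry encryption information.
    service.commit_block_list(uploading_block_list, headers=headers)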
Yep makes sense, done.
        return call_endpoint(get_databricks_host_creds(),
                             endpoint, method, json_body, response_proto)

    def _create_json_body(self, run_id, path=None):
@arjundc-db I think this is unused. Can we remove it?
Done, thanks for the catch.
                                     'databricks/mlflow-tracking/path/to/artifact/..')

        self.run_id = self._extract_run_id()
        self._SERVICE_AND_METHOD_TO_INFO = {
Can we define this dictionary at the module level rather than the class level? The contents of this dictionary should not change based on the particular instance that's using it. Oftentimes, artifact repositories are instantiated many times over the course of a single MLflow training session, so it would be great to avoid having this overhead every time an artifact repository is created.
Done. Cool point!
from mlflow.utils.databricks_utils import get_databricks_host_creds

_PATH_PREFIX = "/api/2.0"
_AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE = 256000000 - 1  # Can upload blob in single request if it is no more thn 256 MB
Nit: Typo - thn >> than (though I think we'll want to remove this constant anyway along with the < 256MB case - see comment below).
Yep, removing it.
        pass

    def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path):
        if cloud_credentials.credentials.type == 1:
Nit: Can we do this instead for readability?:
-    if cloud_credentials.credentials.type == 1:
+    if cloud_credentials.credentials.type == ArtifactCredentials.AZURE_SAS_URI:
Done.
            raise MlflowException('Not implemented yet')

    def _download_from_cloud(self, cloud_credentials, local_path):
        if cloud_credentials.credentials.type == 1:
Nit: Can we do this instead for readability?:
-    if cloud_credentials.credentials.type == 1:
+    if cloud_credentials.credentials.type == ArtifactCredentials.AZURE_SAS_URI:
Done
        else:
            uploading_block_list = list()
            for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE):
                signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri
@arjundc-db Can we use the credentials that were passed as a parameter to this _azure_upload_file method for the first chunk? Otherwise, I think we're wasting that set of credentials (it's only being used to determine whether the credentials are for AWS or Azure).
In the current approach I wanted to make _azure_upload_file independent of the initial credential fetch. However, after the refactoring (and generating new creds only when the old ones expire), it makes sense to reuse the original credentials. Done.
            with open(local_file, "rb") as data:
                service.upload_blob(data, overwrite=True)
        else:
            uploading_block_list = list()
@arjundc-db This logic is awesome! I think it could really benefit from some inline comments explaining why we're chunking the file ourselves, why we fetch credentials when we do (once before each chunk and once before the block list commit operation). Can you add these here?
Done
        # If `path` is a file, ListArtifacts returns a single list element with the
        # same name as `path`. The list_artifacts API expects us to return an empty list in this
        # case, so we do so here.
        if len(artifact_list) == 1 and artifact_list[0].path == path:
            return []
Great catch!
Edit: Saw this logic is used in dbfs_artifact_repo as well - unfortunately, it looks like it's a bit brittle. To test it, I created a file called foo (it was created inside a directory also called arty/foo, but that's not important). Then, I ran the following code:
In[0]: client.list_artifacts(run_id, "arty/foo/foo")
Out[0]: []
In[1]: client.list_artifacts(run_id, "arty/foo/foo////")
Out[1]: [<FileInfo: file_size=5, is_dir=False, path='arty/foo/foo'>]
Looks like those trailing slashes break the path comparison. However, it's a bit strange that ListArtifacts returns the file path arty/foo/foo for an input of arty/foo/foo/////, since this output path is a file (for files, test/ and test// are not equivalent...) and arty/foo/foo///// is not a prefix of arty/foo/foo. I think we should file a bug report ticket for the ListArtifacts API for this case.
This makes sense. To fix this, does it make sense to perform a regex check, since if a user wants to download a single file, the name will probably end with an extension such as .txt, .py, or .pkl? This could be a restriction if a single file needs to be downloaded.
Left a few comments on tests / inline docs - most are very minor! Great work!
@@ -159,11 +161,23 @@ def dbfs_artifact_repo_factory(artifact_uri):

    This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
    storage can only be used together with the RestStore.

    In the special case where the URI is of the form
    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>,
Hyper nit:
-    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>,
+    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,
Done.
LIST_ARTIFACTS_PROTO_RESPONSE = [FileInfo(path='test/a.txt', is_dir=False, file_size=100),
                                 FileInfo(path='test/dir', is_dir=True, file_size=0)]

LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE = [FileInfo(path='a.txt', is_dir=False, file_size=0)]
Nit: If we don't foresee these constants being used in multiple test cases, we may as well move their contents into the relevant test cases; this makes it easier to understand what's going on in each individual test case.
Got you. Done.
TEST_FILE_1_CONTENT = u"Hello 🍆🍔".encode("utf-8")
TEST_FILE_2_CONTENT = u"World 🍆🍔🍆".encode("utf-8")
TEST_FILE_3_CONTENT = u"¡🍆🍆🍔🍆🍆!".encode("utf-8")
It looks like we don't ever read this content within any of the test cases. Accordingly, I think we can move these constants directly into the test_file and test_dir fixtures to clean things up.
Done.
@arjundc-db This is looking great! I left a few nit comments about documentation and testing, as well as some path handling stuff (posixpath vs os.path). Once these are addressed, we should be all set to file a PR against the main MLflow repo and work towards getting it merged (there's a bit of work I have to do to support a paginated list_artifacts() API before merging).
""" | ||
The artifact_uri is expected to be | ||
dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path> | ||
Once the path from the inputted uri is extracted and normalized, is is |
Nit:
-    Once the path from the inputted uri is extracted and normalized, is is
+    Once the path from the input uri is extracted and normalized, it is
Done.
                                     'databricks/mlflow-tracking/path/to/artifact/..')
        self.run_id = self._extract_run_id()

    def _extract_run_id(self):
Can we make this a static method that takes the URI as an input parameter? Defining a no-argument method that operates on a class property to return a new value feels a bit confusing.
Can I rather make it a helper function in the file, something like https://livegrep.dev.databricks.com/view/mlflow/mlflow/mlflow/store/artifact/dbfs_artifact_repo.py#L147 ?
I think a static function within the class makes the most sense, since this is only used within the class itself. The @staticmethod decorator may be useful here.
Will do!
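Combined with the earlier urlparse/normpath suggestion, that static method could look roughly like this (a sketch; the docstring wording follows the snippet quoted later in this review):

import posixpath
from urllib.parse import urlparse

# Inside DatabricksArtifactRepository:
@staticmethod
def _extract_run_id(artifact_uri):
    """
    The artifact_uri is expected to be of the form
    dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>.
    Once the path is extracted and normalized, the run ID is its fourth component.
    """
    artifact_path = posixpath.normpath(urlparse(artifact_uri).path)
    return artifact_path.lstrip("/").split("/")[3]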
        if mlflow.utils.databricks_utils.is_dbfs_fuse_available() \
    uri_scheme = get_uri_scheme(artifact_uri)
    if uri_scheme != 'dbfs':
        raise Exception("DBFS URI must be of the form "
I think this should be an MlflowException
Done.
        with requests.get(signed_read_uri, stream=True) as response:
            response.raise_for_status()
            with open(local_file, "wb") as output_file:
                for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE):
@arjundc-db Can we add a comment right above this line explaining why we're leveraging iter_content() rather than reading all of the content at once?
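For example, the snippet above with such a comment added (the comment wording is only illustrative):

with requests.get(signed_read_uri, stream=True) as response:
    response.raise_for_status()
    with open(local_file, "wb") as output_file:
        # Stream the response with iter_content() rather than reading response.content
        # all at once, so that large artifacts are written to disk chunk-by-chunk
        # instead of being buffered entirely in memory.
        for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE):
            output_file.write(chunk)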
Done.
        raise MlflowException('DatabricksArtifactRepository URI must start with dbfs:/')
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException('Artifact URI incorrect. Expected path prefix to be '
Nit: Can we specify error_code=INVALID_PARAMETER_VALUE for these exceptions? Here's an example: https://github.com/mlflow/mlflow/blob/88858b529bc1b49f37a4a3e7087649f85e62bf1a/mlflow/sklearn.py#L142
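Concretely, that would look something like the following sketch, using the error code constant from mlflow.protos.databricks_pb2:

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

raise MlflowException('DatabricksArtifactRepository URI must start with dbfs:/',
                      error_code=INVALID_PARAMETER_VALUE)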
Done.
        def download_file(fullpath):
            fullpath = fullpath.rstrip('/')
This is awesome! Can we add a brief inline comment above this line explaining why we're stripping trailing slashes from the suffix, just so it's clear for folks reading this code later on?
Done.
    def log_artifact(self, local_file, artifact_path=None):
        basename = os.path.basename(local_file)
        artifact_path = artifact_path or ""
        artifact_path = os.path.join(artifact_path, basename)
artifact_path is used to define the destination for the uploaded artifacts on DBFS; DBFS is a posix filesystem and uses forward slashes as separators. os.path.join() will use the caller's native filesystem for path concatenation, which will result in the artifact path being joined with a backwards slash on Windows systems. Accordingly, we should call posixpath.join() here instead.
Note that the preceding call to os.path.basename() on line 154 is perfectly acceptable/correct since that line is resolving the base name for a path on the caller's local filesystem.
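A quick illustration of the difference (ntpath is the os.path implementation used on Windows, so this runs anywhere; the path components are made up):

import ntpath      # os.path on Windows
import posixpath   # always uses forward slashes, which is what DBFS expects

ntpath.join("model", "data", "model.pkl")     # 'model\\data\\model.pkl' - wrong separator for DBFS
posixpath.join("model", "data", "model.pkl")  # 'model/data/model.pkl'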
This is pretty cool to learn about, thanks for the explanation.
            if dirpath != local_dir:
                rel_path = os.path.relpath(dirpath, local_dir)
                rel_path = relative_path_to_artifact_path(rel_path)
                artifact_subdir = os.path.join(artifact_path, rel_path)
As in the comment above, artifact_subdir is used to define the destination for the uploaded artifacts on DBFS; DBFS is a posix filesystem and uses forward slashes as separators. os.path.join() will use the caller's native filesystem for path concatenation, which will result in the artifact path being joined with a backwards slash on Windows systems. Accordingly, we should call posixpath.join() here instead.
Note that the preceding call to os.path.relpath is correct since it is resolving the relative location of a subdirectory to a parent directory on the caller's local filesystem. This relative path (which may contain backslashes on Windows systems) is then converted to a POSIX-style path component of a URI by the relative_path_to_artifact_path function, which invokes urllib.pathname2url to perform the conversion.
Done, thanks!
class DatabricksArtifactRepository(ArtifactRepository):
    """
    Stores artifacts on Azure/AWS with access control.
Nit: I think we can probably have a better description here, such as:
Performs storage operations on artifacts in the access-controlled `dbfs:/databricks/mlflow-tracking` location. Signed access URIs for S3 / Azure Blob Storage are fetched from the MLflow service and used to read and write files from/to this location.
Done.
                    service.stage_block(block_id, chunk, headers=headers)
                    uploading_block_list.append(block_id)
                try:
                    service.commit_block_list(uploading_block_list)
After doing some digging through the docs, I think we should include the headers on the commit_block_list operation, since these headers are likely to be used for encryption information; I'd imagine that this information is just as relevant for a block commit as it is for a block upload. Sorry for the runaround here.
Done.
                    credentials = self._get_write_credentials(self.run_id,
                                                               artifact_path).credentials.signed_uri
                    service = BlobClient.from_blob_url(blob_url=credentials, credential=None)
                    service.commit_block_list(uploading_block_list)
After doing some digging through the docs, I think we should include the headers on the commit_block_list operation, since these headers are likely to be used for encryption information; I'd imagine that this information is just as relevant for a block commit as it is for a block upload. Sorry for the runaround here.
Done.
        headers = dict()
        for header in credential.headers:
            headers[header.name] = header.value
        return headers
Nit:
headers = dict() | |
for header in credential.headers: | |
headers[header.name] = header.value | |
return headers | |
return { | |
header.name: header.value | |
for header in headers | |
} |
Done.
        elif cloud_credentials.credentials.type == ArtifactCredentialType.AWS_PRESIGNED_URL:
            self._aws_upload_file(cloud_credentials.credentials, local_file)
        else:
            raise MlflowException('Not implemented yet')
As we do in _download_from_cloud, we should use "Cloud provider not supported" in the exception text.
Done.
        if cloud_credential.type not in [ArtifactCredentialType.AZURE_SAS_URI,
                                         ArtifactCredentialType.AWS_PRESIGNED_URL]:
            raise MlflowException(message='Cloud provider not supported.',
                                  error_code=INVALID_PARAMETER_VALUE)
Nice exception! I think the error code here should be INTERNAL_ERROR in this case (the default error code), since the user has no control over this failure mode. If the MLflow service suddenly starts passing strange credential types to the client, there's nothing the end user can do about it.
Done.
            aws_upload_mock.return_value = None
            databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
            write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
            aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)
@arjundc-db Is it possible to dig into the _aws_upload_file method a little further and assert that the expected headers are passed to requests.put? This would be a great sanity check that we're correctly handling headers in the client. I think it would be great to do this with a separate test case such as test_log_artifact_aws_with_headers.
A similar test_log_artifact_azure_with_headers case would also be awesome. Hopefully this doesn't require too much mocking - presumably, just requests.put for AWS and a couple of BlobClient operations for Azure, since we're already creating actual files in the tmpdir test fixture; let me know if you can give this a shot!
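A rough sketch of what the AWS-headers case might look like (the fixtures come from this test module; the exact shape of the requests.put call inside _aws_upload_file is an assumption):

from collections import namedtuple
from unittest import mock

MockHeader = namedtuple("MockHeader", ["name", "value"])

def test_log_artifact_aws_with_headers(databricks_artifact_repo, test_file):
    mock_credentials = mock.Mock()
    mock_credentials.signed_uri = "http://mock-aws-signed-uri"
    # Illustrative header that the service might return for server-side encryption
    mock_credentials.headers = [MockHeader("x-amz-server-side-encryption", "aws:kms")]
    with mock.patch("requests.put") as request_mock:
        databricks_artifact_repo._aws_upload_file(mock_credentials, test_file.strpath)
        # The key assertion: the credential headers are forwarded to requests.put
        request_mock.assert_called_with("http://mock-aws-signed-uri", data=mock.ANY,
                                        headers={"x-amz-server-side-encryption": "aws:kms"})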
Done! wew
            write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
            aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)

    def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
Can we be more descriptive with this test case name? What is the failure mode we're looking at? (e.g., unexpected exceptions encountered during artifact uploading)
Done.
            write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
            aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)

    def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
Nit:
-    def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
+    def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file):
Done.
            mock_blob_service.from_blob_url().return_value = MlflowException("MOCK ERROR")
            with pytest.raises(MlflowException):
                databricks_artifact_repo.log_artifact(test_file.strpath)
                write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
I don't think this assert function will be called since the pytest.raises scope terminates as soon as an exception is encountered. You should be able to fix this by moving the assert call out of the pytest.raises context (i.e., unindent 1 level).
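i.e., roughly:

with pytest.raises(MlflowException):
    databricks_artifact_repo.log_artifact(test_file.strpath)
# Runs after the raises-context has exited, so it will actually be executed
write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)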
Done.
Path resolution fixes for DatabricksArtifactRepository
Paginated client
Signed-off-by: Saransh Sharma <[email protected]> Co-authored-by: Saransh Sharma <[email protected]>
What changes are proposed in this pull request?
(Please fill in changes proposed in this fix)
How is this patch tested?
(Details)
Release Notes
Is this a user-facing change?
(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/projects: MLproject format, project running backends
- area/scoring: Local serving, model deployment tools, spark UDFs
- area/tracking: Tracking Service, tracking client APIs, autologging
Interface
- area/uiux: Front-end, user experience, JavaScript, plotting
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support
Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
How should the PR be classified in the release notes? Choose one:
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes