
Databricks artifact repo #5

Open · wants to merge 34 commits into master

Conversation

dbczumar
Owner

What changes are proposed in this pull request?

(Please fill in changes proposed in this fix)

How is this patch tested?

(Details)

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for
    Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: Local serving, model deployment tools, spark UDFs
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, JavaScript, plotting
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

@dbczumar dbczumar left a comment

@arjundc-db Awesome work! I left a few comments. I still need to get my hands dirty with this against a Databricks environment; I'll go ahead and do that shortly. Let me know if you have any questions about the comments!

        super(DatabricksArtifactRepository, self).__init__(artifact_uri)

    def _extract_run_id(self, artifact_uri):
        return artifact_uri.lstrip('/').split('/')[4]
dbczumar (Owner Author)

@arjundc-db Awesome that we're extracting the run ID from the URI! A quick suggestion here - can we: 1. use the urlparse library to fetch the path component of the URI and 2. normalize the path before performing the splitting? E.g.,

import posixpath
from urllib.parse import urlparse

parsed_uri = urlparse(artifact_uri)
parsed_path = posixpath.normpath(parsed_uri.path)
# parsed_path components: databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/...
return parsed_path.lstrip("/").split("/")[3]

The reasoning for #1 here is that DBFS URIs support empty hostnames; e.g., dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID is valid (note the 3 leading slashes). I confirmed this by running %fs ls dbfs:///databricks in a Databricks notebook. The current solution yields the incorrect result when applied directly to a URI of this form:

In [1]: url = "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
In [2]: url.lstrip('/').split('/')[4]
Out[2]: 'mlflow-tracking' # Should be 'RUN_ID'

Similarly, I've confirmed that dbfs://databricks/ (two leading slashes) and dbfs:databricks/... (no leading slashes) are invalid use cases; Databricks rejects URIs of the first form because they contain non-empty hostnames, which are unsupported. Databricks rejects URIs of the second form because they contain relative paths. Databricks requires URIs to specify absolute DBFS paths. This means that we don't have to worry about relative paths, which may not contain a leading slash.

The reasoning for #2 here is that URI paths may contain redundant slashes. For example, dbfs:/databricks/mlflow-tracking////EXP_ID/RUN_ID is a valid URI that is semantically equivalent to dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID. Normalizing the path will remove redundant slashes. Without normalizing the path, we currently get an incorrect result:

In [1]: url = "dbfs:/databricks////mlflow-tracking/EXP_ID/RUN_ID/"
In [2]: url.lstrip('/').split('/')[4]
Out[2]: ''

For more information about the structure of URIs, https://en.wikipedia.org/wiki/Uniform_Resource_Identifier is an awesome reference. Let me know if you have any questions here.

These kinds of URI edge cases would be great to handle via a unit test. Let's make sure we plan to add a unit test for this case before merging this functionality!
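For example, a parametrized test along these lines would cover the equivalent-URI cases above (a sketch only; it assumes _extract_run_id ends up as a helper that takes the URI directly, as discussed later in this thread):

import pytest

from mlflow.store.artifact.databricks_artifact_repo import DatabricksArtifactRepository


@pytest.mark.parametrize("artifact_uri", [
    "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
    # Empty hostname (three leading slashes) is valid for DBFS URIs
    "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
    # Redundant slashes are semantically equivalent and should be normalized away
    "dbfs:/databricks/mlflow-tracking////EXP_ID/RUN_ID/artifacts",
])
def test_extract_run_id_handles_equivalent_uris(artifact_uri):
    assert DatabricksArtifactRepository._extract_run_id(artifact_uri) == "RUN_ID"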

dbczumar (Owner Author)

Once we correct the behavior here, can we also add an inline comment that clearly documents the assumptions we're making about the URI structure? E.g., URIs are assumed to be semantically equivalent to dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts/... ? This will help future readers of the code interpret the behavior.

@arjundc-db arjundc-db May 19, 2020

Done, this was pretty cool to read about, thanks!

"path": path
}

def _get_azure_write_credentials(self, run_id, path=None):
dbczumar (Owner Author)

This method calls an MLflow Artifacts Service endpoint that is not directly bound to Azure; it may also return AWS credentials. The type of credential is specified by the type field of the ArtifactCredentialInfo response. Accordingly, I think we should call this _get_write_credentials.

Callers of this method should check the type field of the response and then perform the corresponding cloud operation (either AWS or Azure). The current task is to implement Azure uploads / downloads, so it's fine to leave the AWS case unimplemented (e.g., simply raise a "Not Implemented" exception in the code). It would still be great to structure the upload / download logic so that it accounts for the fact that credentials may be provided for different cloud services. Let me know if this makes sense!
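For instance, the caller-side dispatch could look roughly like this (a sketch only; the helper names and the ArtifactCredentialType import path are assumptions based on this discussion, not the final code):

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialType


def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path):
    # The MLflow Artifacts Service indicates which cloud the signed URI targets via the
    # `type` field of the ArtifactCredentialInfo response, so dispatch on it here.
    credential_info = cloud_credentials.credentials
    if credential_info.type == ArtifactCredentialType.AZURE_SAS_URI:
        self._azure_upload_file(credential_info, local_file, artifact_path)
    elif credential_info.type == ArtifactCredentialType.AWS_PRESIGNED_URL:
        # AWS uploads are out of scope for this change; fail loudly for now.
        raise MlflowException("Not implemented yet")
    else:
        raise MlflowException("Cloud provider not supported.")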


Done.

        return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForWrite,
                                   self._create_json_body(run_id, path))

def _get_azure_read_credentials(self, run_id, path=None):
dbczumar (Owner Author)

This method calls an MLflow Artifacts Service endpoint that is not directly bound to Azure; it may also return AWS credentials. The type of credential is specified by the type field of the ArtifactCredentialInfo response. Accordingly, I think we should call this _get_read_credentials.

Callers of this method should check the type field of the response and then perform the corresponding cloud operation (either AWS or Azure). The current task is to implement Azure uploads / downloads, so it's fine to leave the AWS case unimplemented (e.g., simply raise a "Not Implemented" exception in the code). It would still be great to structure the upload / download logic so that it accounts for the fact that credentials may be provided for different cloud services. Let me know if this makes sense!


Done.

@@ -129,3 +129,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path):
# joined path
suffix_path = suffix_path.lstrip(posixpath.sep)
return posixpath.join(prefix_path, suffix_path)


def is_artifact_acled_uri(artifact_uri):
dbczumar (Owner Author)

In the same vein as the comment above about semantically equivalent URIs (e.g., dbfs:///databricks/mlflow-tracking////EXP_ID/RUN_ID/...), this check will fail. Can we do the following instead?:

  1. Parse the URI
  2. Extract the scheme and verify that it's dbfs
  3. Extract the path, normalize it via posixpath.normpath()
  4. Verify that the normalized path has the expected prefix

It would also be great to unit test this method against various semantically equivalent URIs.
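Something along these lines, as a sketch (the exact prefix constant and return semantics are assumptions, not the final code):

import posixpath
from urllib.parse import urlparse


def is_artifact_acled_uri(artifact_uri):
    # Steps 1-4 above: parse the URI, verify the dbfs scheme, normalize the path,
    # and check that it starts with the ACL'd mlflow-tracking prefix.
    parsed_uri = urlparse(artifact_uri)
    if parsed_uri.scheme != "dbfs":
        return False
    normalized_path = posixpath.normpath(parsed_uri.path)
    return normalized_path.startswith("/databricks/mlflow-tracking")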


Done. However, I don't think step 2 is necessary, since https://livegrep.dev.databricks.com/view/mlflow/mlflow/mlflow/store/artifact/artifact_repository_registry.py#L63 only calls dbfs_artifact_repo_factory when the scheme is dbfs.

dbczumar (Owner Author)

@arjundc-db I think we're relying on an assumption about exactly where this function is going to be called. Future callers of this function are not guaranteed to perform dbfs scheme validation beforehand. Accordingly, we have a few options:

  1. Perform the scheme validation in this method
  2. Do not perform the scheme validation in this method and clearly document (via the method docstring) that this method assumes it is being passed a dbfs:/ URI and will not validate the scheme
  3. If we don't foresee this method being called elsewhere, we can define it within the dbfs_artifact_repo_factory() method as a subroutine. This way, we can safely assume that we're already working with a dbfs URI (as documented by the dbfs_artifact_repo_factory docstring).

I'd vote for either option 1 or option 3!


Yep, that makes sense. Using option 3.

@@ -129,3 +129,7 @@ def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path):
# joined path
suffix_path = suffix_path.lstrip(posixpath.sep)
return posixpath.join(prefix_path, suffix_path)


def is_artifact_acled_uri(artifact_uri):
dbczumar (Owner Author)

Nit: can we rename this to is_databricks_acled_artifacts_uri? This helps to clarify that we're checking for a piece of Databricks-specific functionality.


Done.

@@ -8,7 +8,7 @@

_INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \
"format specifications."

_ACLED_ARTIFACT_URI = "dbfs:/databricks/mlflow-tracking/"
dbczumar (Owner Author)

Nit: Can we move this into is_artifact_acled_uri()? It doesn't seem to be required by more than one function in this module.


Done.

        return artifact_uri.lstrip('/').split('/')[4]

    def _call_endpoint(self, service, api, json_body):
        _METHOD_TO_INFO = extract_api_info_for_service(service, _PATH_PREFIX)
dbczumar (Owner Author)

This extract_api_info_for_service method is a bit expensive (https://github.com/mlflow/mlflow/blob/c4e0cbdf5c92b3b81f7080fff4a974114605d2db/mlflow/utils/rest_utils.py#L115). Can we instead define a module-level _SERVICE_AND_METHOD_TO_INFO dictionary as:

_SERVICE_AND_METHOD_TO_INFO = {
    service: extract_api_info_for_service(service, _PATH_PREFIX)
    for service in [MlflowService, DatabricksMlflowArtifactsService]
}

We can then fetch the info as _SERVICE_AND_METHOD_TO_INFO[service][api], without performing proto extraction operations each time this method is called.
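Putting the two pieces together, _call_endpoint could then look roughly like this (a sketch; the api.Response() construction is an assumption about how the response proto is obtained):

def _call_endpoint(self, service, api, json_body):
    # The endpoint info for every service was extracted once at module import time,
    # so this is just a dictionary lookup rather than a proto descriptor walk.
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    response_proto = api.Response()
    return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto)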


Done.

return call_endpoint(get_databricks_host_creds(), endpoint, method, json_body, response_proto)

def _create_json_body(self, run_id, path=None):
path = path or '.'
dbczumar (Owner Author)

Nit:

Suggested change:
- path = path or '.'
+ path = path or ""

"" is more idiomatic for representing empty than "."


Done.

@codecov-commenter

codecov-commenter commented May 20, 2020

Codecov Report

Merging #5 into master will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master       #5   +/-   ##
=======================================
  Coverage   85.04%   85.04%           
=======================================
  Files          20       20           
  Lines        1050     1050           
=======================================
  Hits          893      893           
  Misses        157      157           

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7969f91...dfc7f60.

service: extract_api_info_for_service(service, _PATH_PREFIX)
for service in [MlflowService, DatabricksMlflowArtifactsService]
}
self.credential_type_to_cloud_service = {


To be deleted. I was considering a dictionary (Cloud_Provider -> upload/download_function) but it looked too messy.

@dbczumar dbczumar left a comment

@arjundc-db Awesome progress! I left a few more comments. Still need to take a look at the tests, but this is awesome!

mlflow/store/artifact/artifact_repo.py (resolved)
Comment on lines 71 to 75
if os.path.getsize(local_file) < _AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE:
    signed_write_uri = credentials.signed_uri
    service = BlobClient.from_blob_url(blob_url=signed_write_uri, credential=None)
    with open(local_file, "rb") as data:
        service.upload_blob(data, overwrite=True)
dbczumar (Owner Author)

@arjundc-db Examining the documentation of BlobClient.upload_blob() (https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobclient?view=azure-python#upload-blob-data--blob-type--blobtype-blockblob---blockblob----length-none--metadata-none----kwargs-), it appears that this method performs a block blob upload with automatic chunking (e.g., multiple chunks may be used for a single file of size < 256MB).

This Python client API is not equivalent to the single-part blob upload API (PutBlob): https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob.

Accordingly, I think it's probably best to unconditionally use the 100MB self-chunking approach regardless of aggregate file size. Let me know if you have any questions about this.
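Roughly, the unconditional self-chunking upload could be structured like this (a sketch only; the chunk-size constant, the block-ID scheme, and the omission of header handling are simplifications, not the final implementation):

import base64
import uuid

from azure.storage.blob import BlobBlock, BlobClient

_AZURE_MAX_BLOCK_CHUNK_SIZE = 100 * 1024 * 1024  # stage blocks of at most 100 MB each


def _azure_upload_file(self, credentials, local_file, artifact_path):
    # Chunk the file ourselves, stage each chunk as a block against the signed URI,
    # then commit the accumulated block list so the service assembles the blob.
    service = BlobClient.from_blob_url(blob_url=credentials.signed_uri, credential=None)
    uploading_block_list = []
    with open(local_file, "rb") as f:
        while True:
            chunk = f.read(_AZURE_MAX_BLOCK_CHUNK_SIZE)
            if not chunk:
                break
            block_id = base64.b64encode(uuid.uuid4().hex.encode()).decode("utf-8")
            service.stage_block(block_id, chunk)
            uploading_block_list.append(BlobBlock(block_id=block_id))
    service.commit_block_list(uploading_block_list)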


Yep makes sense, done.

return call_endpoint(get_databricks_host_creds(),
endpoint, method, json_body, response_proto)

def _create_json_body(self, run_id, path=None):
dbczumar (Owner Author)

@arjundc-db I think this is unused. Can we remove it?


Done, thanks for the catch.

'databricks/mlflow-tracking/path/to/artifact/..')

self.run_id = self._extract_run_id()
self._SERVICE_AND_METHOD_TO_INFO = {
dbczumar (Owner Author)

Can we define this dictionary at the module level rather than the class level? The contents of this dictionary should not change based on the particular instance that's using it. Artifact repositories are often instantiated many times over the course of a single MLflow training session, so it would be great to avoid incurring this overhead every time an artifact repository is created.


Done. Cool point!

from mlflow.utils.databricks_utils import get_databricks_host_creds

_PATH_PREFIX = "/api/2.0"
_AZURE_SINGLE_BLOCK_BLOB_MAX_SIZE = 256000000 - 1 # Can upload blob in single request if it is no more thn 256 MB
dbczumar (Owner Author)

Nit: Typo - thn >> than (though I think we'll want to remove this constant anyway along with the < 256MB case - see comment below).


Yep, removing it.

pass

def _upload_to_cloud(self, cloud_credentials, local_file, artifact_path):
if cloud_credentials.credentials.type == 1:
dbczumar (Owner Author)

Nit: Can we do this instead for readability?:

Suggested change:
- if cloud_credentials.credentials.type == 1:
+ if cloud_credentials.credentials.type == ArtifactCredentials.AZURE_SAS_URI:


Done.

raise MlflowException('Not implemented yet')

def _download_from_cloud(self, cloud_credentials, local_path):
if cloud_credentials.credentials.type == 1:
dbczumar (Owner Author)

Nit: Can we do this instead for readability?:

Suggested change:
- if cloud_credentials.credentials.type == 1:
+ if cloud_credentials.credentials.type == ArtifactCredentials.AZURE_SAS_URI:


Done

else:
    uploading_block_list = list()
    for chunk in yield_file_in_chunks(local_file, _AZURE_MAX_BLOCK_CHUNK_SIZE):
        signed_write_uri = self._get_write_credentials(self.run_id, artifact_path).credentials.signed_uri
dbczumar (Owner Author)

@arjundc-db Can we use the credentials that were passed as a parameter to this _azure_upload_file method for the first chunk? Otherwise, I think we're wasting that set of credentials (it's only being used to determine whether the credentials are for AWS or Azure).


In the current approach I wanted to make _azure_upload_file independent of the initial credential fetch. However, after the refactoring (and generating new credentials only when the old ones expire), it makes sense to reuse the original credentials. Done.

with open(local_file, "rb") as data:
service.upload_blob(data, overwrite=True)
else:
uploading_block_list = list()
dbczumar (Owner Author)

@arjundc-db This logic is awesome! I think it could really benefit from some inline comments explaining why we're chunking the file ourselves and why we fetch credentials when we do (once before each chunk and once before the block list commit operation). Can you add these here?


Done

Comment on lines 140 to 144
# If `path` is a file, ListArtifacts returns a single list element with the
# same name as `path`. The list_artifacts API expects us to return an empty list in this
# case, so we do so here.
if len(artifact_list) == 1 and artifact_list[0].path == path:
return []
dbczumar (Owner Author)

Great catch!

Edit: Saw this logic is used in dbfs_artifact_repo as well - unfortunately, it looks like it's a bit brittle. To test it, I created a file called foo (it was created inside a directory also called arty/foo, but that's not important). Then, I ran the following code:

In[0]: client.list_artifacts(run_id, "arty/foo/foo")
Out[0]: []

In[1]: client.list_artifacts(run_id, "arty/foo/foo////")
Out[1]: [<FileInfo: file_size=5, is_dir=False, path='arty/foo/foo'>]

Looks like those trailing slashes break the path comparison. However, it's a bit strange that ListArtifacts returns the file path arty/foo/foo for an input of arty/foo/foo////, since this output path is a file (for files, test/ and test// are not equivalent...) and arty/foo/foo//// is not a prefix of arty/foo/foo. I think we should file a bug ticket against the ListArtifacts API for this case.


This makes sense. To fix it, does it make sense to perform a regex check? If a user wants to download a single file, the path will probably end with an extension such as .txt, .py, or .pkl, to name a few. This could be a restriction when a single file needs to be downloaded.
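For reference, an alternative that avoids guessing at file extensions is to normalize the requested path before the comparison, e.g. (a sketch only, not the final fix):

import posixpath

# Collapse redundant and trailing slashes so 'arty/foo/foo////' compares equal to 'arty/foo/foo'
normalized_path = posixpath.normpath(path) if path else path
if len(artifact_list) == 1 and artifact_list[0].path == normalized_path:
    return []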

@dbczumar dbczumar left a comment

Left a few comments on tests / inline docs - most are very minor! Great work!

@@ -159,11 +161,23 @@ def dbfs_artifact_repo_factory(artifact_uri):

This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
storage can only be used together with the RestStore.

In the special case where the URI is of the form
``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>,
dbczumar (Owner Author)

Hyper nit:

Suggested change:
- ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>,
+ ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,


Done.

mlflow/store/artifact/databricks_artifact_repo.py (outdated, resolved)
Comment on lines 48 to 51
LIST_ARTIFACTS_PROTO_RESPONSE = [FileInfo(path='test/a.txt', is_dir=False, file_size=100),
                                 FileInfo(path='test/dir', is_dir=True, file_size=0)]

LIST_ARTIFACTS_SINGLE_FILE_PROTO_RESPONSE = [FileInfo(path='a.txt', is_dir=False, file_size=0)]
dbczumar (Owner Author)

Nit: If we don't foresee these constants being used in multiple test cases, we may as well move their contents into the relevant test cases; this makes it easier to understand what's going on in each individual test case.


Got you. Done.

Comment on lines 24 to 26
TEST_FILE_1_CONTENT = u"Hello 🍆🍔".encode("utf-8")
TEST_FILE_2_CONTENT = u"World 🍆🍔🍆".encode("utf-8")
TEST_FILE_3_CONTENT = u"¡🍆🍆🍔🍆🍆!".encode("utf-8")
dbczumar (Owner Author)

It looks like we don't ever read this content within any of the test cases. Accordingly, I think we can move these constants directly into the test_file and test_dir fixtures to clean things up.


Done.

tests/store/artifact/test_databricks_artifact_repo.py (outdated, resolved)
tests/store/artifact/test_databricks_artifact_repo.py (outdated, resolved)
tests/utils/test_uri.py (resolved)
@dbczumar dbczumar left a comment

@arjundc-db This is looking great! I left a few nit comments about documentation and testing, as well as some path handling stuff (posixpath vs os.path). Once these are addressed, we should be all set to file a PR against the main MLflow repo and work towards getting it merged (there's a bit of work I have to do to support a paginated list_artifacts() API before merging).

"""
The artifact_uri is expected to be
dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
Once the path from the inputted uri is extracted and normalized, is is
dbczumar (Owner Author)

Nit:

Suggested change:
- Once the path from the inputted uri is extracted and normalized, is is
+ Once the path from the input uri is extracted and normalized, it is


Done.

'databricks/mlflow-tracking/path/to/artifact/..')
self.run_id = self._extract_run_id()

def _extract_run_id(self):
dbczumar (Owner Author)

Can we make this a static method that takes the URI as an input parameter? Defining a no-argument method that operates on a class property to return a new value feels a bit confusing.


Can I rather make it a helper function in the file, something like https://livegrep.dev.databricks.com/view/mlflow/mlflow/mlflow/store/artifact/dbfs_artifact_repo.py#L147 ?

dbczumar (Owner Author)

I think a static function within the class makes the most sense, since this is only used within the class itself. The @staticmethod decorator may be useful here.


Will do!

if mlflow.utils.databricks_utils.is_dbfs_fuse_available() \
uri_scheme = get_uri_scheme(artifact_uri)
if uri_scheme != 'dbfs':
raise Exception("DBFS URI must be of the form "
dbczumar (Owner Author)

I think this should be an MlflowException


Done.

with requests.get(signed_read_uri, stream=True) as response:
    response.raise_for_status()
    with open(local_file, "wb") as output_file:
        for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE):
dbczumar (Owner Author)

@arjundc-db Can we add a comment right above this line explaining why we're leveraging iter_content() rather than reading all of the content at once?
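One possible shape for that comment, as a sketch of the annotated loop (the final write line is assumed, since the excerpt above cuts off at the loop header):

with requests.get(signed_read_uri, stream=True) as response:
    response.raise_for_status()
    with open(local_file, "wb") as output_file:
        # Stream the download via iter_content() in fixed-size chunks rather than reading
        # the whole response body at once, so large artifacts never have to fit in memory.
        for chunk in response.iter_content(chunk_size=_AZURE_MAX_BLOCK_CHUNK_SIZE):
            output_file.write(chunk)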


Done.

Comment on lines 43 to 45
raise MlflowException('DatabricksArtifactRepository URI must start with dbfs:/')
if not is_databricks_acled_artifacts_uri(artifact_uri):
raise MlflowException('Artifact URI incorrect. Expected path prefix to be '
dbczumar (Owner Author)

Nit: Can we specify error_code=INVALID_PARAMETER_VALUE for these exceptions? Here's an example: https://github.com/mlflow/mlflow/blob/88858b529bc1b49f37a4a3e7087649f85e62bf1a/mlflow/sklearn.py#L142


Done.

def download_file(fullpath):
    fullpath = fullpath.rstrip('/')
dbczumar (Owner Author)

This is awesome! Can we add a brief inline comment above this line explaining why we're stripping trailing slashes from the suffix, just so it's clear for folks reading this code later on?


Done.

def log_artifact(self, local_file, artifact_path=None):
    basename = os.path.basename(local_file)
    artifact_path = artifact_path or ""
    artifact_path = os.path.join(artifact_path, basename)
dbczumar (Owner Author)

artifact_path is used to define the destination for the uploaded artifacts on DBFS; DBFS is a posix filesystem and uses forward slashes as separators. os.path.join() uses the caller's native filesystem separator for path concatenation, which would result in the artifact path being joined with a backslash on Windows systems. Accordingly, we should call posixpath.join() here instead.

Note that the preceding call to os.path.basename() on line 154 is perfectly acceptable/correct since that line is resolving the base name for a path on the caller's local filesystem.
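A quick illustration of the difference (hypothetical values; the second result is what a Windows caller would produce):

import os.path
import posixpath

posixpath.join("model", "conda.yaml")  # 'model/conda.yaml' on every platform - what DBFS expects
os.path.join("model", "conda.yaml")    # 'model\\conda.yaml' when run on Windows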


This is pretty cool to learn about, thanks for the explanation.

if dirpath != local_dir:
    rel_path = os.path.relpath(dirpath, local_dir)
    rel_path = relative_path_to_artifact_path(rel_path)
    artifact_subdir = os.path.join(artifact_path, rel_path)
dbczumar (Owner Author)

As in the comment above, artifact_subdir is used to define the destination for the uploaded artifacts on DBFS; DBFS is a posix filesystem and uses forward slashes as separators. os.path.join() uses the caller's native filesystem separator for path concatenation, which would result in the artifact path being joined with a backslash on Windows systems. Accordingly, we should call posixpath.join() here instead.

Note that the preceding call to os.path.relpath is correct since it is resolving the relative location of a subdirectory to a parent directory on the caller's local filesystem. This relative path (which may contain backslashes on Windows systems) is then converted to a POSIX-style path component of a URI by the relative_path_to_artifact_path function, which invokes urllib.pathname2url to perform the conversion.


Done, thanks!

mlflow/store/artifact/databricks_artifact_repo.py (outdated, resolved)

class DatabricksArtifactRepository(ArtifactRepository):
    """
    Stores artifacts on Azure/AWS with access control.
dbczumar (Owner Author)

Nit: I think we can probably have a better description here, such as:

Performs storage operations on artifacts in the access-controlled `dbfs:/databricks/mlflow-tracking` location. Signed access URIs for S3 / Azure Blob Storage are fetched from the MLflow service and used to read and write files from/to this location.


Done.

    service.stage_block(block_id, chunk, headers=headers)
    uploading_block_list.append(block_id)
try:
    service.commit_block_list(uploading_block_list)
dbczumar (Owner Author)

After doing some digging through the docs, I think we should include the headers on the commit_block_list operation, since these headers are likely to be used for encryption information; I'd imagine that this information is just as relevant for a block commit as it is for a block upload. Sorry for the runaround here.


Done.

credentials = self._get_write_credentials(self.run_id,
                                          artifact_path).credentials.signed_uri
service = BlobClient.from_blob_url(blob_url=credentials, credential=None)
service.commit_block_list(uploading_block_list)
dbczumar (Owner Author)

After doing some digging through the docs, I think we should include the headers on the commit_block_list operation, since these headers are likely to be used for encryption information; I'd imagine that this information is just as relevant for a block commit as it is for a block upload. Sorry for the runaround here.


Done.

Comment on lines 90 to 93
headers = dict()
for header in credential.headers:
    headers[header.name] = header.value
return headers
dbczumar (Owner Author)

Nit:

Suggested change:
- headers = dict()
- for header in credential.headers:
-     headers[header.name] = header.value
- return headers
+ return {
+     header.name: header.value
+     for header in credential.headers
+ }


Done.

elif cloud_credentials.credentials.type == ArtifactCredentialType.AWS_PRESIGNED_URL:
    self._aws_upload_file(cloud_credentials.credentials, local_file)
else:
    raise MlflowException('Not implemented yet')
dbczumar (Owner Author)

As we do in _download_from_cloud, we should use 'Cloud provider not supported' in the exception text.


Done.

if cloud_credential.type not in [ArtifactCredentialType.AZURE_SAS_URI,
                                 ArtifactCredentialType.AWS_PRESIGNED_URL]:
    raise MlflowException(message='Cloud provider not supported.',
                          error_code=INVALID_PARAMETER_VALUE)
dbczumar (Owner Author)

Nice exception! I think the error code here should be INTERNAL_ERROR in this case (the default error code), since the user has no control over this failure mode. If the MLflow service suddenly starts passing strange credential types to the client, there's nothing the end user can do about it.


Done.

aws_upload_mock.return_value = None
databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)
@dbczumar dbczumar Jun 3, 2020

@arjundc-db Is it possible to dig into the _aws_upload_file method a little further and assert that the expected headers are passed to requests.put? This would be a great sanity check that we're correctly handling headers in the client. I think it would be great to do this with a separate test case such as test_log_artifact_aws_with_headers.

A similar test_log_artifact_azure_with_headers case would also be awesome. Hopefully this doesn't require too much mocking - presumably, just requests.put for AWS and a couple of BlobClient operations for Azure, since we're already creating actual files in the tmpdir test fixture; let me know if you can give this a shot!


Done! wew

write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)

def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
@dbczumar dbczumar Jun 3, 2020

Can we be more descriptive with this test case name? What is the failure mode we're looking at? (e.g., unexpected exceptions encountered during artifact uploading)


Done.

write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)

def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
dbczumar (Owner Author)

Nit:

Suggested change:
- def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file, ):
+ def test_log_artifact_fail_case(self, databricks_artifact_repo, test_file):


Done.

mock_blob_service.from_blob_url().return_value = MlflowException("MOCK ERROR")
with pytest.raises(MlflowException):
    databricks_artifact_repo.log_artifact(test_file.strpath)
    write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
dbczumar (Owner Author)

I don't think this assert function will be called since the pytest.raises scope terminates as soon as an exception is encountered. You should be able to fix this by moving the assert call out of the pytest.raises context (i.e., unindent 1 level).
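In other words, roughly (sketch):

# The assertion runs after the pytest.raises block, so it still executes even though
# log_artifact raises inside the block.
with pytest.raises(MlflowException):
    databricks_artifact_repo.log_artifact(test_file.strpath)
write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)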


Done.

dbczumar pushed a commit that referenced this pull request Oct 23, 2023