Skip to content

Commit

Permalink
Implement GraphQL endpoint for ListArtifacts API (mlflow#12602)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniellok-db authored Jul 15, 2024
1 parent 99933f8 commit 2008fc6
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 47 deletions.
76 changes: 38 additions & 38 deletions mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mlflow/protos/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ service MlflowService {
visibility: PUBLIC,
rpc_doc_title: "List Artifacts",
};
option (graphql) = {};
}

// Get a list of all values for the specified metric for a given run.
Expand Down
12 changes: 6 additions & 6 deletions mlflow/protos/service_pb2.py

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions mlflow/server/graphql/autogenerated_graphql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ class MlflowGetMetricHistoryBulkIntervalResponse(graphene.ObjectType):
metrics = graphene.List(graphene.NonNull(MlflowMetricWithRunId))


class MlflowFileInfo(graphene.ObjectType):
path = graphene.String()
is_dir = graphene.Boolean()
file_size = LongString()


class MlflowListArtifactsResponse(graphene.ObjectType):
root_uri = graphene.String()
files = graphene.List(graphene.NonNull(MlflowFileInfo))
next_page_token = graphene.String()


class MlflowDataset(graphene.ObjectType):
name = graphene.String()
digest = graphene.String()
Expand Down Expand Up @@ -173,6 +185,13 @@ class MlflowGetMetricHistoryBulkIntervalInput(graphene.InputObjectType):
max_results = graphene.Int()


class MlflowListArtifactsInput(graphene.InputObjectType):
run_id = graphene.String()
run_uuid = graphene.String()
path = graphene.String()
page_token = graphene.String()


class MlflowSearchRunsInput(graphene.InputObjectType):
experiment_ids = graphene.List(graphene.String)
filter = graphene.String()
Expand All @@ -195,6 +214,7 @@ class QueryType(graphene.ObjectType):
mlflow_get_experiment = graphene.Field(MlflowGetExperimentResponse, input=MlflowGetExperimentInput())
mlflow_get_metric_history_bulk_interval = graphene.Field(MlflowGetMetricHistoryBulkIntervalResponse, input=MlflowGetMetricHistoryBulkIntervalInput())
mlflow_get_run = graphene.Field(MlflowGetRunResponse, input=MlflowGetRunInput())
mlflow_list_artifacts = graphene.Field(MlflowListArtifactsResponse, input=MlflowListArtifactsInput())
mlflow_search_model_versions = graphene.Field(MlflowSearchModelVersionsResponse, input=MlflowSearchModelVersionsInput())

def resolve_mlflow_get_experiment(self, info, input):
Expand All @@ -215,6 +235,12 @@ def resolve_mlflow_get_run(self, info, input):
parse_dict(input_dict, request_message)
return mlflow.server.handlers.get_run_impl(request_message)

def resolve_mlflow_list_artifacts(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.service_pb2.ListArtifacts()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.list_artifacts_impl(request_message)

def resolve_mlflow_search_model_versions(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.model_registry_pb2.SearchModelVersions()
Expand Down
11 changes: 8 additions & 3 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,13 @@ def _list_artifacts():
"page_token": [_assert_string],
},
)
response_message = list_artifacts_impl(request_message)
response = Response(mimetype="application/json")
response.set_data(message_to_json(response_message))
return response


def list_artifacts_impl(request_message):
response_message = ListArtifacts.Response()
if request_message.HasField("path"):
path = request_message.path
Expand All @@ -1044,9 +1051,7 @@ def _list_artifacts():

response_message.files.extend([a.to_proto() for a in artifact_entities])
response_message.root_uri = run.info.artifact_uri
response = Response(mimetype="application/json")
response.set_data(message_to_json(response_message))
return response
return response_message


@catch_mlflow_exception
Expand Down
20 changes: 20 additions & 0 deletions mlflow/server/js/src/graphql/autogenerated_schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type Query {
mlflowGetExperiment(input: MlflowGetExperimentInput): MlflowGetExperimentResponse
mlflowGetMetricHistoryBulkInterval(input: MlflowGetMetricHistoryBulkIntervalInput): MlflowGetMetricHistoryBulkIntervalResponse
mlflowGetRun(input: MlflowGetRunInput): MlflowGetRunResponse
mlflowListArtifacts(input: MlflowListArtifactsInput): MlflowListArtifactsResponse
mlflowSearchModelVersions(input: MlflowSearchModelVersionsInput): MlflowSearchModelVersionsResponse

"""Simple echoing field"""
Expand Down Expand Up @@ -173,6 +174,25 @@ input MlflowGetRunInput {
runUuid: String
}

type MlflowListArtifactsResponse {
rootUri: String
files: [MlflowFileInfo!]
nextPageToken: String
}

type MlflowFileInfo {
path: String
isDir: Boolean
fileSize: LongString
}

input MlflowListArtifactsInput {
runId: String
runUuid: String
path: String
pageToken: String
}

type MlflowSearchModelVersionsResponse {
modelVersions: [MlflowModelVersion!]
nextPageToken: String
Expand Down
51 changes: 51 additions & 0 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,3 +2294,54 @@ def test_search_runs_graphql(mlflow_client):
{"info": {"runId": created_run_1.info.run_id}},
]
assert json["data"]["mlflowSearchRuns"]["runs"] == expected


def test_list_artifacts_graphql(mlflow_client, tmp_path):
name = "GraphqlTest"
experiment_id = mlflow_client.create_experiment(name)
created_run_id = mlflow_client.create_run(experiment_id).info.run_id
file_path = tmp_path / "test.txt"
file_path.write_text("hello world")
mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix())
mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix(), "testDir")

response = requests.post(
f"{mlflow_client.tracking_uri}/graphql",
json={
"query": f"""
fragment FilesFragment on MlflowListArtifactsResponse {{
files {{
path
isDir
fileSize
}}
}}
query testQuery {{
file: mlflowListArtifacts(input: {{ runId: "{created_run_id}" }}) {{
...FilesFragment
}}
subdir: mlflowListArtifacts(input: {{
runId: "{created_run_id}",
path: "testDir",
}}) {{
...FilesFragment
}}
}}
""",
"operationName": "testQuery",
},
headers={"content-type": "application/json; charset=utf-8"},
)

assert response.status_code == 200
json = response.json()
file_expected = [
{"path": "test.txt", "isDir": False, "fileSize": "11"},
{"path": "testDir", "isDir": True, "fileSize": "0"},
]
assert json["data"]["file"]["files"] == file_expected
subdir_expected = [
{"path": "testDir/test.txt", "isDir": False, "fileSize": "11"},
]
assert json["data"]["subdir"]["files"] == subdir_expected

0 comments on commit 2008fc6

Please sign in to comment.