Skip to content

Commit

Permalink
Implement GraphQL endpoint for SearchRuns API (mlflow#12601)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <[email protected]>
  • Loading branch information
daniellok-db authored Jul 15, 2024
1 parent 7f90f9c commit f4b7ed8
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 51 deletions.
83 changes: 42 additions & 41 deletions mlflow/java/client/src/main/java/org/mlflow/api/proto/Service.java

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mlflow/protos/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ service MlflowService {
visibility: PUBLIC,
rpc_doc_title: "Search Runs",
};
option (graphql) = {};
}

// List artifacts for a run. Takes an optional ``artifact_path`` prefix which if specified,
Expand Down
12 changes: 6 additions & 6 deletions mlflow/protos/service_pb2.py

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion mlflow/server/graphql/autogenerated_graphql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class MlflowRunStatus(graphene.Enum):
KILLED = 5


class MlflowViewType(graphene.Enum):
ACTIVE_ONLY = 1
DELETED_ONLY = 2
ALL = 3


class MlflowModelVersionTag(graphene.ObjectType):
key = graphene.String()
value = graphene.String()
Expand Down Expand Up @@ -124,6 +130,11 @@ class MlflowRun(graphene.ObjectType):
inputs = graphene.Field(MlflowRunInputs)


class MlflowSearchRunsResponse(graphene.ObjectType):
runs = graphene.List(graphene.NonNull('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension'))
next_page_token = graphene.String()


class MlflowGetRunResponse(graphene.ObjectType):
run = graphene.Field('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension')

Expand Down Expand Up @@ -162,6 +173,15 @@ class MlflowGetMetricHistoryBulkIntervalInput(graphene.InputObjectType):
max_results = graphene.Int()


class MlflowSearchRunsInput(graphene.InputObjectType):
experiment_ids = graphene.List(graphene.String)
filter = graphene.String()
run_view_type = graphene.Field(MlflowViewType)
max_results = graphene.Int()
order_by = graphene.List(graphene.String)
page_token = graphene.String()


class MlflowGetRunInput(graphene.InputObjectType):
run_id = graphene.String()
run_uuid = graphene.String()
Expand Down Expand Up @@ -203,4 +223,10 @@ def resolve_mlflow_search_model_versions(self, info, input):


class MutationType(graphene.ObjectType):
pass
mlflow_search_runs = graphene.Field(MlflowSearchRunsResponse, input=MlflowSearchRunsInput())

def resolve_mlflow_search_runs(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.service_pb2.SearchRuns()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.search_runs_impl(request_message)
11 changes: 8 additions & 3 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,13 @@ def _search_runs():
"order_by": [_assert_array, _assert_item_type_string],
},
)
response_message = search_runs_impl(request_message)
response = Response(mimetype="application/json")
response.set_data(message_to_json(response_message))
return response


def search_runs_impl(request_message):
response_message = SearchRuns.Response()
run_view_type = ViewType.ACTIVE_ONLY
if request_message.HasField("run_view_type"):
Expand All @@ -1004,9 +1011,7 @@ def _search_runs():
response_message.runs.extend([r.to_proto() for r in run_entities])
if run_entities.token:
response_message.next_page_token = run_entities.token
response = Response(mimetype="application/json")
response.set_data(message_to_json(response_message))
return response
return response_message


@catch_mlflow_exception
Expand Down
22 changes: 22 additions & 0 deletions mlflow/server/js/src/graphql/autogenerated_schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,32 @@ type Test {
}

type Mutation {
mlflowSearchRuns(input: MlflowSearchRunsInput): MlflowSearchRunsResponse

"""Simple echoing field"""
testMutation(inputString: String): TestMutation
}

type MlflowSearchRunsResponse {
runs: [MlflowRunExtension!]
nextPageToken: String
}

input MlflowSearchRunsInput {
experimentIds: [String]
filter: String
runViewType: MlflowViewType = null
maxResults: Int
orderBy: [String]
pageToken: String
}

enum MlflowViewType {
ACTIVE_ONLY
DELETED_ONLY
ALL
}

type TestMutation {
"""Echoes the input string"""
output: String
Expand Down
35 changes: 35 additions & 0 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,3 +2259,38 @@ def test_get_metric_history_bulk_interval_graphql(mlflow_client):
json = response.json()
expected = [{"key": metric_name, "timestamp": mock.ANY, "value": i} for i in range(10)]
assert json["data"]["mlflowGetMetricHistoryBulkInterval"]["metrics"] == expected


def test_search_runs_graphql(mlflow_client):
name = "GraphqlTest"
mlflow_client.create_registered_model(name)
experiment_id = mlflow_client.create_experiment(name)
created_run_1 = mlflow_client.create_run(experiment_id)
created_run_2 = mlflow_client.create_run(experiment_id)

response = requests.post(
f"{mlflow_client.tracking_uri}/graphql",
json={
"query": f"""
mutation testMutation {{
mlflowSearchRuns(input: {{ experimentIds: ["{experiment_id}"] }}) {{
runs {{
info {{
runId
}}
}}
}}
}}
""",
"operationName": "testMutation",
},
headers={"content-type": "application/json; charset=utf-8"},
)

assert response.status_code == 200
json = response.json()
expected = [
{"info": {"runId": created_run_2.info.run_id}},
{"info": {"runId": created_run_1.info.run_id}},
]
assert json["data"]["mlflowSearchRuns"]["runs"] == expected

0 comments on commit f4b7ed8

Please sign in to comment.