Skip to content

Commit

Permalink
Fix CI error (#2220)
Browse files Browse the repository at this point in the history
* fix ci

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
pingsutw authored and fiedlerNr9 committed Jul 25, 2024
1 parent a57c41f commit 3c7f358
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 67 deletions.
9 changes: 1 addition & 8 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,6 @@ def execute(self, **kwargs) -> typing.Any:
self._agent = AgentRegistry.get_agent(task_template.type)

res = asyncio.run(self._create(task_template, output_prefix, kwargs))

# If the task is synchronous, the agent will return the output from the resource literals.
if res.HasField("resource"):
if res.resource.phase != TaskExecution.SUCCEEDED:
raise FlyteUserException(f"Failed to run the task {self._entity.name}")
return LiteralMap.from_flyte_idl(res.resource.outputs)

res = asyncio.run(self._get(resource_meta=res.resource_meta))

if res.resource.phase != TaskExecution.SUCCEEDED:
Expand Down Expand Up @@ -249,7 +242,7 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse:
visible=True,
)
log_links = ""
for link in res.log_links:
for link in res.resource.log_links:
log_links += f"{link.name}: {link.uri}\n"
if log_links:
progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
}
).to_flyte_idl()

return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res, log_links=log_links))

def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
client = bigquery.Client()
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def __init__(self):
res.resource.outputs.literals["results"].scalar.structured_dataset.uri
== "bq://dummy_project:dummy_dataset.dummy_table"
)
assert res.log_links[0].name == "BigQuery Console"
assert res.resource.log_links[0].name == "BigQuery Console"
assert (
res.log_links[0].uri
res.resource.log_links[0].uri
== "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults"
)
agent.delete(metadata_bytes)
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}"
log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()]

return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=cur_phase, message=message, log_links=log_links))

async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
metadata = pickle.loads(resource_meta)
Expand Down
2 changes: 0 additions & 2 deletions plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ async def test_databricks_agent():
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()
assert res.resource.message == "OK"
assert res.log_links[0].name == "Databricks Console"
assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123"

mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response)
await agent.delete(metadata_bytes)
Expand Down
57 changes: 4 additions & 53 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
Expand Down Expand Up @@ -60,8 +58,9 @@ def create(

def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
return GetTaskResponse(
resource=Resource(phase=TaskExecution.SUCCEEDED),
log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()],
resource=Resource(
phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()]
),
)

def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
Expand Down Expand Up @@ -90,24 +89,6 @@ async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
return DeleteTaskResponse()


class SyncDummyAgent(AgentBase):
name = "Sync Dummy Agent"

def __init__(self):
super().__init__(task_type="sync_dummy")

def create(
self,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs,
) -> CreateTaskResponse:
return CreateTaskResponse(
resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl())
)


def get_task_template(task_type: str) -> TaskTemplate:
@task
def simple_task(i: int):
Expand Down Expand Up @@ -146,8 +127,6 @@ def test_dummy_agent():
assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
res = agent.get(metadata_bytes)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.log_links[0].name == "console"
assert res.log_links[0].uri == "localhost:3000"
assert agent.delete(metadata_bytes) == DeleteTaskResponse()

class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):
Expand Down Expand Up @@ -186,19 +165,6 @@ async def test_async_dummy_agent():
assert agent_metadata.supported_task_types == ["async_dummy"]


@pytest.mark.asyncio
async def test_sync_dummy_agent():
AgentRegistry.register(SyncDummyAgent())
agent = AgentRegistry.get_agent("sync_dummy")
res = agent.create("/tmp", sync_dummy_template, task_inputs)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent")
assert agent_metadata.name == "Sync Dummy Agent"
assert agent_metadata.supported_task_types == ["sync_dummy"]


@pytest.mark.asyncio
async def run_agent_server():
service = AsyncAgentService()
Expand All @@ -209,10 +175,6 @@ async def run_agent_server():
async_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl()
)
sync_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl()
)
fake_agent = "fake"
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")

res = await service.CreateTask(request, ctx)
Expand All @@ -229,17 +191,6 @@ async def run_agent_server():
res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

res = await service.CreateTask(sync_request, ctx)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None

metadata_service = AgentMetadataService()
res = await metadata_service.ListAgent(ListAgentsRequest(), ctx)
assert isinstance(res, ListAgentsResponse)


def test_agent_server():
loop.run_in_executor(None, run_agent_server)
Expand Down

0 comments on commit 3c7f358

Please sign in to comment.