diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 5a6e5cd3bf..4648f5ecdf 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -177,13 +177,6 @@ def execute(self, **kwargs) -> typing.Any: self._agent = AgentRegistry.get_agent(task_template.type) res = asyncio.run(self._create(task_template, output_prefix, kwargs)) - - # If the task is synchronous, the agent will return the output from the resource literals. - if res.HasField("resource"): - if res.resource.phase != TaskExecution.SUCCEEDED: - raise FlyteUserException(f"Failed to run the task {self._entity.name}") - return LiteralMap.from_flyte_idl(res.resource.outputs) - res = asyncio.run(self._get(resource_meta=res.resource_meta)) if res.resource.phase != TaskExecution.SUCCEEDED: @@ -249,7 +242,7 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: visible=True, ) log_links = "" - for link in res.log_links: + for link in res.resource.log_links: log_links += f"{link.name}: {link.uri}\n" if log_links: progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 4c34285793..0418d4f809 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -114,7 +114,7 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links) + return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res, log_links=log_links)) def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: client = bigquery.Client() diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index dc2af4ab80..0293ebd3cc 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -96,9 +96,9 @@ def __init__(self): res.resource.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" ) - assert res.log_links[0].name == "BigQuery Console" + assert res.resource.log_links[0].name == "BigQuery Console" assert ( - res.log_links[0].uri + res.resource.log_links[0].uri == "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults" ) agent.delete(metadata_bytes) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 2fe442182a..c75617b0e0 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -102,7 +102,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links) + return GetTaskResponse(resource=Resource(phase=cur_phase, message=message, log_links=log_links)) async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: metadata = pickle.loads(resource_meta) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 5d19b3402f..f875800268 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -126,8 +126,6 @@ async def test_databricks_agent(): assert res.resource.phase == TaskExecution.SUCCEEDED assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() assert res.resource.message == "OK" - assert res.log_links[0].name == "Databricks Console" - assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123" mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) await agent.delete(metadata_bytes) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 68dee74b3f..ce9b9e5b9b 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -14,15 +14,13 @@ DeleteTaskResponse, GetTaskRequest, GetTaskResponse, - ListAgentsRequest, - ListAgentsResponse, Resource, ) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings -from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService +from flytekit.extend.backend.agent_service import AsyncAgentService from flytekit.extend.backend.base_agent import ( AgentBase, AgentRegistry, @@ -60,8 +58,9 @@ def create( def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: return GetTaskResponse( - resource=Resource(phase=TaskExecution.SUCCEEDED), - log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()], + resource=Resource( + phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()] + ), ) def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: @@ -90,24 +89,6 @@ async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: return DeleteTaskResponse() -class SyncDummyAgent(AgentBase): - name = "Sync Dummy Agent" - - def __init__(self): - super().__init__(task_type="sync_dummy") - - def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: - return CreateTaskResponse( - resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl()) - ) - - def get_task_template(task_type: str) -> TaskTemplate: @task def simple_task(i: int): @@ -146,8 +127,6 @@ def test_dummy_agent(): assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes res = agent.get(metadata_bytes) assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.log_links[0].name == "console" - assert res.log_links[0].uri == "localhost:3000" assert agent.delete(metadata_bytes) == DeleteTaskResponse() class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): @@ -186,19 +165,6 @@ async def test_async_dummy_agent(): assert agent_metadata.supported_task_types == ["async_dummy"] -@pytest.mark.asyncio -async def test_sync_dummy_agent(): - AgentRegistry.register(SyncDummyAgent()) - agent = AgentRegistry.get_agent("sync_dummy") - res = agent.create("/tmp", sync_dummy_template, task_inputs) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == LiteralMap({}).to_flyte_idl() - - agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent") - assert agent_metadata.name == "Sync Dummy Agent" - assert agent_metadata.supported_task_types == ["sync_dummy"] - - @pytest.mark.asyncio async def run_agent_server(): service = AsyncAgentService() @@ -209,10 +175,6 @@ async def run_agent_server(): async_request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) - sync_request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl() - ) - fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") res = await service.CreateTask(request, ctx) @@ -229,17 +191,6 @@ async def run_agent_server(): res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert isinstance(res, DeleteTaskResponse) - res = await service.CreateTask(sync_request, ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == LiteralMap({}).to_flyte_idl() - - res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) - assert res is None - - metadata_service = AgentMetadataService() - res = await metadata_service.ListAgent(ListAgentsRequest(), ctx) - assert isinstance(res, ListAgentsResponse) - def test_agent_server(): loop.run_in_executor(None, run_agent_server)