Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI error #2220

Merged
merged 10 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,6 @@ def execute(self, **kwargs) -> typing.Any:
self._agent = AgentRegistry.get_agent(task_template.type)

res = asyncio.run(self._create(task_template, output_prefix, kwargs))

# If the task is synchronous, the agent will return the output from the resource literals.
if res.HasField("resource"):
if res.resource.phase != TaskExecution.SUCCEEDED:
raise FlyteUserException(f"Failed to run the task {self._entity.name}")
return LiteralMap.from_flyte_idl(res.resource.outputs)

res = asyncio.run(self._get(resource_meta=res.resource_meta))

if res.resource.phase != TaskExecution.SUCCEEDED:
Expand Down Expand Up @@ -249,7 +242,7 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse:
visible=True,
)
log_links = ""
for link in res.log_links:
for link in res.resource.log_links:
log_links += f"{link.name}: {link.uri}\n"
if log_links:
progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
}
).to_flyte_idl()

return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res, log_links=log_links))

def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
client = bigquery.Client()
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def __init__(self):
res.resource.outputs.literals["results"].scalar.structured_dataset.uri
== "bq://dummy_project:dummy_dataset.dummy_table"
)
assert res.log_links[0].name == "BigQuery Console"
assert res.resource.log_links[0].name == "BigQuery Console"
assert (
res.log_links[0].uri
res.resource.log_links[0].uri
== "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults"
)
agent.delete(metadata_bytes)
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}"
log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()]

return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links)
return GetTaskResponse(resource=Resource(phase=cur_phase, message=message, log_links=log_links))

async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
metadata = pickle.loads(resource_meta)
Expand Down
2 changes: 0 additions & 2 deletions plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ async def test_databricks_agent():
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()
assert res.resource.message == "OK"
assert res.log_links[0].name == "Databricks Console"
assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123"

mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response)
await agent.delete(metadata_bytes)
Expand Down
57 changes: 4 additions & 53 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
Expand Down Expand Up @@ -60,8 +58,9 @@ def create(

def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
return GetTaskResponse(
resource=Resource(phase=TaskExecution.SUCCEEDED),
log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()],
resource=Resource(
phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()]
),
)

def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
Expand Down Expand Up @@ -90,24 +89,6 @@ async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
return DeleteTaskResponse()


class SyncDummyAgent(AgentBase):
name = "Sync Dummy Agent"

def __init__(self):
super().__init__(task_type="sync_dummy")

def create(
self,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs,
) -> CreateTaskResponse:
return CreateTaskResponse(
resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl())
)


def get_task_template(task_type: str) -> TaskTemplate:
@task
def simple_task(i: int):
Expand Down Expand Up @@ -146,8 +127,6 @@ def test_dummy_agent():
assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
res = agent.get(metadata_bytes)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.log_links[0].name == "console"
assert res.log_links[0].uri == "localhost:3000"
assert agent.delete(metadata_bytes) == DeleteTaskResponse()

class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):
Expand Down Expand Up @@ -186,19 +165,6 @@ async def test_async_dummy_agent():
assert agent_metadata.supported_task_types == ["async_dummy"]


@pytest.mark.asyncio
async def test_sync_dummy_agent():
AgentRegistry.register(SyncDummyAgent())
agent = AgentRegistry.get_agent("sync_dummy")
res = agent.create("/tmp", sync_dummy_template, task_inputs)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent")
assert agent_metadata.name == "Sync Dummy Agent"
assert agent_metadata.supported_task_types == ["sync_dummy"]


@pytest.mark.asyncio
async def run_agent_server():
service = AsyncAgentService()
Expand All @@ -209,10 +175,6 @@ async def run_agent_server():
async_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl()
)
sync_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl()
)
fake_agent = "fake"
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")

res = await service.CreateTask(request, ctx)
Expand All @@ -229,17 +191,6 @@ async def run_agent_server():
res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

res = await service.CreateTask(sync_request, ctx)
assert res.resource.phase == TaskExecution.SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None

metadata_service = AgentMetadataService()
res = await metadata_service.ListAgent(ListAgentsRequest(), ctx)
assert isinstance(res, ListAgentsResponse)


def test_agent_server():
loop.run_in_executor(None, run_agent_server)
Expand Down
Loading