Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate custom_info Dict through agent Resource #2426

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 3 additions & 24 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import (
AgentMetadataServiceServicer,
Expand All @@ -25,8 +24,7 @@
)
from prometheus_client import Counter, Summary

from flytekit import FlyteContext, logger
from flytekit.core.type_engine import TypeEngine
from flytekit import logger
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -136,16 +134,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)
logger.info(f"{agent.name} start checking the status of the job")
res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)
return GetTaskResponse(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
return GetTaskResponse(resource=res.to_flyte_idl())

@record_agent_metrics
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
Expand Down Expand Up @@ -178,17 +167,7 @@ async def ExecuteTaskSync(
agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix
)

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)

header = ExecuteTaskSyncResponseHeader(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
header = ExecuteTaskSyncResponseHeader(resource=res.to_flyte_idl())
yield ExecuteTaskSyncResponse(header=header)
request_success_count.labels(task_type=task_type, operation=do_operation).inc()
except Exception as e:
Expand Down
36 changes: 35 additions & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from typing import Any, Dict, List, Optional, Union

from flyteidl.admin.agent_pb2 import Agent
from flyteidl.admin.agent_pb2 import Resource as _Resource
from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory
from flyteidl.core import literals_pb2
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct
from rich.logging import RichHandler
from rich.progress import Progress

Expand All @@ -28,6 +31,7 @@
from flytekit.exceptions.user import FlyteUserException
from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template
from flytekit.loggers import set_flytekit_log_properties
from flytekit.models import common
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate

Expand Down Expand Up @@ -76,7 +80,7 @@ def decode(cls, data: bytes) -> "ResourceMeta":


@dataclass
class Resource:
class Resource(common.FlyteIdlEntity):
"""
This is the output resource of the job.
Expand All @@ -91,6 +95,36 @@ class Resource:
message: Optional[str] = None
log_links: Optional[List[TaskLog]] = None
outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None
custom_info: Optional[typing.Dict[str, Any]] = None

def to_flyte_idl(self) -> _Resource:
    """
    Build the flyteidl ``Resource`` protobuf message for this resource.

    ``outputs`` may be ``None``, a ``LiteralMap`` or a plain dict of Python
    values; a dict is converted through ``TypeEngine`` using the current
    ``FlyteContext``. ``custom_info`` is serialized to JSON and parsed into a
    protobuf ``Struct``; a ``None`` or empty ``custom_info`` is passed as ``None``.

    :return: The protobuf ``Resource`` message.
    """
    raw_outputs = self.outputs
    if raw_outputs is None:
        literal_map_pb = None
    elif isinstance(raw_outputs, LiteralMap):
        # Already a flytekit model, so it knows how to serialize itself.
        literal_map_pb = raw_outputs.to_flyte_idl()
    else:
        # Plain Python values: let the type engine produce the literal map.
        literal_map_pb = TypeEngine.dict_to_literal_map_pb(FlyteContext.current_context(), raw_outputs)

    custom_info_pb = None
    if self.custom_info:
        # Round-trip through JSON to populate a Struct from the dict.
        custom_info_pb = json_format.Parse(json.dumps(self.custom_info), Struct())

    return _Resource(
        phase=self.phase,
        message=self.message,
        log_links=self.log_links,
        outputs=literal_map_pb,
        custom_info=custom_info_pb,
    )

@classmethod
def from_flyte_idl(cls, pb2_object: _Resource) -> "Resource":
    """
    Build a ``Resource`` from its flyteidl protobuf message.

    ``outputs`` and ``custom_info`` are message-typed fields, so their
    presence is checked with ``HasField``; when a field is not set on the
    message the corresponding attribute is ``None``.

    :param pb2_object: The protobuf ``Resource`` message to convert.
    :return: The equivalent ``Resource`` instance.
    """
    # A plain protobuf message instance is always truthy, so testing
    # ``pb2_object.outputs`` directly can never select the ``None`` branch.
    # ``HasField`` reports whether the field was actually set.
    outputs = LiteralMap.from_flyte_idl(pb2_object.outputs) if pb2_object.HasField("outputs") else None
    custom_info = json_format.MessageToDict(pb2_object.custom_info) if pb2_object.HasField("custom_info") else None

    return cls(
        phase=pb2_object.phase,
        message=pb2_object.message,
        log_links=pb2_object.log_links,
        outputs=outputs,
        custom_info=custom_info,
    )


class AgentBase(ABC):
Expand Down
96 changes: 86 additions & 10 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@

from flytekit import PythonFunctionTask, task
from flytekit.clis.sdk_in_container.serve import print_agents_metadata
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.configuration import (
FastSerializationSettings,
Image,
ImageConfig,
SerializationSettings,
)
from flytekit.core.base_task import PythonTask, kwtypes
from flytekit.core.interface import Interface
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService
from flytekit.extend.backend.agent_service import (
AgentMetadataService,
AsyncAgentService,
SyncAgentService,
)
from flytekit.extend.backend.base_agent import (
AgentRegistry,
AsyncAgentBase,
Expand Down Expand Up @@ -71,7 +80,11 @@ def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap
return DummyMetadata(job_id=dummy_id)

def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
return Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="console", uri="localhost:3000")],
custom_info={"custom": "info", "num": 1},
)

def delete(self, resource_meta: DummyMetadata, **kwargs):
...
Expand All @@ -96,7 +109,11 @@ async def create(
return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name)

async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
return Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="console", uri="localhost:3000")],
custom_info={"custom": "info", "num": 1},
)

async def delete(self, resource_meta: DummyMetadata, **kwargs):
...
Expand All @@ -108,7 +125,12 @@ class MockOpenAIAgent(SyncAgentBase):
def __init__(self):
super().__init__(task_type_name="openai")

def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource:
def do(
self,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs,
) -> Resource:
assert inputs.literals["a"].scalar.primitive.integer == 1
return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1})

Expand Down Expand Up @@ -174,6 +196,8 @@ def test_dummy_agent():
assert resource.phase == TaskExecution.SUCCEEDED
assert resource.log_links[0].name == "console"
assert resource.log_links[0].uri == "localhost:3000"
assert resource.custom_info["custom"] == "info"
assert resource.custom_info["num"] == 1
assert agent.delete(metadata) is None

class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):
Expand All @@ -189,7 +213,9 @@ def __init__(self, **kwargs):


@pytest.mark.parametrize(
"agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"]
"agent,consume_metadata",
[(DummyAgent(), False), (AsyncDummyAgent(), True)],
ids=["sync", "async"],
)
@pytest.mark.asyncio
async def test_async_agent_service(agent, consume_metadata):
Expand Down Expand Up @@ -222,7 +248,10 @@ async def test_async_agent_service(agent, consume_metadata):
assert res.resource_meta == metadata_bytes
res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
assert res.resource.phase == TaskExecution.SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
res = await service.DeleteTask(
DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes),
ctx,
)
assert res == DeleteTaskResponse()

agent_metadata = AgentRegistry.get_agent_metadata(agent.name)
Expand Down Expand Up @@ -269,7 +298,9 @@ def test_openai_agent():
class OpenAITask(SyncAgentExecutorMixin, PythonTask):
def __init__(self, **kwargs):
super().__init__(
task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs
task_type="openai",
interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)),
**kwargs,
)

t = OpenAITask(task_config={}, name="openai task")
Expand Down Expand Up @@ -393,9 +424,54 @@ def test_render_task_template():
@pytest.fixture
def sample_agents():
async_agent = Agent(
name="Sensor", is_sync=False, supported_task_categories=[TaskCategory(name="sensor", version=0)]
name="Sensor",
is_sync=False,
supported_task_categories=[TaskCategory(name="sensor", version=0)],
)
sync_agent = Agent(
name="ChatGPT Agent", is_sync=True, supported_task_categories=[TaskCategory(name="chatgpt", version=0)]
name="ChatGPT Agent",
is_sync=True,
supported_task_categories=[TaskCategory(name="chatgpt", version=0)],
)
return [async_agent, sync_agent]


def test_resource_type():
    """Check ``Resource`` <-> flyteidl conversion for a minimal and a fully populated resource."""
    # Case 1: only the phase is set; every other protobuf field should be at its default.
    minimal = Resource(phase=TaskExecution.SUCCEEDED)
    minimal_pb = minimal.to_flyte_idl()
    assert minimal_pb
    assert minimal_pb.phase == TaskExecution.SUCCEEDED
    assert len(minimal_pb.log_links) == 0
    assert minimal_pb.message == ""
    assert len(minimal_pb.outputs.literals) == 0
    assert len(minimal_pb.custom_info) == 0

    minimal_back = Resource.from_flyte_idl(minimal_pb)
    assert minimal_back

    # Case 2: every field populated, including dict outputs and custom_info.
    populated = Resource(
        phase=TaskExecution.SUCCEEDED,
        log_links=[TaskLog(name="console", uri="localhost:3000")],
        message="foo",
        outputs={"o0": 1},
        custom_info={"custom": "info", "num": 1},
    )
    populated_pb = populated.to_flyte_idl()
    assert populated_pb
    assert populated_pb.phase == TaskExecution.SUCCEEDED
    first_link = populated_pb.log_links[0]
    assert first_link.name == "console"
    assert first_link.uri == "localhost:3000"
    assert populated_pb.message == "foo"
    assert populated_pb.outputs.literals["o0"].scalar.primitive.integer == 1
    assert populated_pb.custom_info["custom"] == "info"
    assert populated_pb.custom_info["num"] == 1

    populated_back = Resource.from_flyte_idl(populated_pb)
    assert populated_back.phase == populated.phase
    assert list(populated_back.log_links) == list(populated.log_links)
    assert populated_back.message == populated.message
    # Dict outputs do not survive the round trip as a dict: they come back as a literal map.
    assert populated_back.outputs.literals["o0"].scalar.primitive.integer == 1
    assert populated_back.custom_info == populated.custom_info
Loading