Skip to content

Commit

Permalink
Add SyncAgentBase (#2146)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
pingsutw authored and fiedlerNr9 committed Jul 25, 2024
1 parent d03886a commit 3a28aff
Show file tree
Hide file tree
Showing 38 changed files with 850 additions and 601 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ jobs:
- python-version: 3.11
plugin-names: "flytekit-whylogs"
steps:
- uses: insightsengineering/disk-space-reclaimer@v1
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand All @@ -296,12 +295,12 @@ jobs:
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('plugins/{0}/requirements.txt', matrix.plugin-names ))) }}
- name: Install dependencies
run: |
export SETUPTOOLS_SCM_PRETEND_VERSION="2.0.0"
make setup
cd plugins/${{ matrix.plugin-names }}
pip install .
if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi
if [ -f dev-requirements.in ]; then pip install -r dev-requirements.in; fi
pip install -U $GITHUB_WORKSPACE
pip install --no-deps -U --force-reinstall "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl"
pip freeze
- name: Test with coverage
run: |
Expand Down
3 changes: 2 additions & 1 deletion Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ COPY . /flytekit
# 3. Clean up the apt cache to reduce image size. Reference: https://gist.github.com/marvell/7c812736565928e602c4
# 4. Create a non-root user 'flytekit' and set appropriate permissions for directories.
RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \
&& pip install "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \
&& pip install --no-cache-dir -U --pre \
flyteidl \
-e /flytekit \
-e /flytekit/plugins/flytekit-k8s-pod \
-e /flytekit/plugins/flytekit-deck-standard \
Expand All @@ -43,6 +43,7 @@ RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \
&& chown flytekit: /home \
&& :


ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:"

# Switch to the 'flytekit' user for better security.
Expand Down
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ update_boilerplate:
.PHONY: setup
setup: install-piptools ## Install requirements
pip install -r dev-requirements.in
pip install --no-deps -U --force-reinstall "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl"

.PHONY: fmt
fmt:
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-e file:.#egg=flytekit
git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl

coverage[toml]
hypothesis
Expand Down
4 changes: 3 additions & 1 deletion flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flyteidl.service.agent_pb2_grpc import (
add_AgentMetadataServiceServicer_to_server,
add_AsyncAgentServiceServicer_to_server,
add_SyncAgentServiceServicer_to_server,
)
from grpc import aio

Expand Down Expand Up @@ -52,7 +53,7 @@ def agent(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService

try:
from prometheus_client import start_http_server
Expand All @@ -64,6 +65,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
add_SyncAgentServiceServicer_to_server(SyncAgentService(), server)
add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)

server.add_insecure_port(f"[::]:{port}")
Expand Down
34 changes: 30 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from google.protobuf.json_format import MessageToDict as _MessageToDict
Expand Down Expand Up @@ -1164,19 +1165,34 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models.
@classmethod
@timeit("Translate literal to python value")
def literal_map_to_kwargs(
cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type]
cls,
ctx: FlyteContext,
lm: LiteralMap,
python_types: typing.Optional[typing.Dict[str, type]] = None,
literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None,
) -> typing.Dict[str, typing.Any]:
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
"""
if len(lm.literals) > len(python_types):
if python_types is None and literal_types is None:
raise ValueError("At least one of python_types or literal_types must be provided")

if literal_types:
python_interface_inputs = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items()
}
else:
python_interface_inputs = python_types # type: ignore

if len(lm.literals) > len(python_interface_inputs):
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
f"Received more input values {len(lm.literals)}"
f" than allowed by the input spec {len(python_interface_inputs)}"
)
kwargs = {}
for i, k in enumerate(lm.literals):
try:
kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k])
kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k])
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc
return kwargs
Expand Down Expand Up @@ -1210,6 +1226,16 @@ def dict_to_literal_map(
raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v)
return LiteralMap(literal_map)

@classmethod
def dict_to_literal_map_pb(
cls,
ctx: FlyteContext,
d: typing.Dict[str, typing.Any],
type_hints: Optional[typing.Dict[str, type]] = None,
) -> Optional[literals_pb2.LiteralMap]:
literal_map = cls.dict_to_literal_map(ctx, d, type_hints)
return literal_map.to_flyte_idl()

@classmethod
def get_available_transformers(cls) -> typing.KeysView[Type]:
"""
Expand Down
137 changes: 102 additions & 35 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
import typing
from http import HTTPStatus

import grpc
from flyteidl.admin.agent_pb2 import (
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
ExecuteTaskSyncRequest,
ExecuteTaskSyncResponse,
ExecuteTaskSyncResponseHeader,
GetAgentRequest,
GetAgentResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import (
AgentMetadataServiceServicer,
AsyncAgentServiceServicer,
SyncAgentServiceServicer,
)
from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer
from prometheus_client import Counter, Summary

from flytekit import logger
from flytekit import FlyteContext, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, mirror_async_methods
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

metric_prefix = "flyte_agent_"
create_operation = "create"
get_operation = "get"
delete_operation = "delete"
do_operation = "do"

# Follow the naming convention. https://prometheus.io/docs/practices/naming/
request_success_count = Counter(
Expand All @@ -46,7 +57,24 @@
input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])


def agent_exception_handler(func: typing.Callable):
def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str):
if isinstance(e, FlyteAgentNotFound):
error_message = f"Cannot find agent for task type: {task_type}."
logger.error(error_message)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()
else:
error_message = f"failed to {operation} {task_type} task with error: {e}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(error_message)
request_failure_count.labels(
task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR
).inc()


def record_agent_metrics(func: typing.Callable):
async def wrapper(
self,
request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
Expand All @@ -60,10 +88,10 @@ async def wrapper(
if request.inputs:
input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
elif isinstance(request, GetTaskRequest):
task_type = request.task_type
task_type = request.task_type or request.task_category.name
operation = get_operation
elif isinstance(request, DeleteTaskRequest):
task_type = request.task_type
task_type = request.task_type or request.task_category.name
operation = delete_operation
else:
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand All @@ -75,51 +103,90 @@ async def wrapper(
res = await func(self, request, context, *args, **kwargs)
request_success_count.labels(task_type=task_type, operation=operation).inc()
return res
except FlyteAgentNotFound:
error_message = f"Cannot find agent for task type: {task_type}."
logger.error(error_message)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc()
except Exception as e:
error_message = f"failed to {operation} {task_type} task with error {e}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc()
_handle_exception(e, context, task_type, operation)

return wrapper


class AsyncAgentService(AsyncAgentServiceServicer):
@agent_exception_handler
@record_agent_metrics
async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
tmp = TaskTemplate.from_flyte_idl(request.template)
template = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(tmp.type)
agent = AgentRegistry.get_agent(template.type, template.task_type_version)

logger.info(f"{tmp.type} agent start creating the job")
return await mirror_async_methods(
agent.create, output_prefix=request.output_prefix, task_template=tmp, inputs=inputs
)
logger.info(f"{agent.name} start creating the job")
resource_mata = await mirror_async_methods(agent.create, task_template=template, inputs=inputs)
return CreateTaskResponse(resource_meta=resource_mata.encode())

@agent_exception_handler
@record_agent_metrics
async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
return await mirror_async_methods(agent.get, resource_meta=request.resource_meta)
if request.task_category and request.task_category.name:
agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
else:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.name} start checking the status of the job")
res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)
return GetTaskResponse(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)

@agent_exception_handler
@record_agent_metrics
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
return await mirror_async_methods(agent.delete, resource_meta=request.resource_meta)
if request.task_category and request.task_category.name:
agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
else:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.name} start deleting the job")
return await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta))


class SyncAgentService(SyncAgentServiceServicer):
async def ExecuteTaskSync(
self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext
) -> typing.AsyncIterator[ExecuteTaskSyncResponse]:
request = await request_iterator.__anext__()
template = TaskTemplate.from_flyte_idl(request.header.template)
task_type = template.type
try:
with request_latency.labels(task_type=task_type, operation=do_operation).time():
agent = AgentRegistry.get_agent(task_type, template.task_type_version)
if not isinstance(agent, SyncAgentBase):
raise ValueError(f"[{agent.name}] does not support sync execution")

request = await request_iterator.__anext__()
literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)

header = ExecuteTaskSyncResponseHeader(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
yield ExecuteTaskSyncResponse(header=header)
request_success_count.labels(task_type=task_type, operation=do_operation).inc()
except Exception as e:
_handle_exception(e, context, template.type, do_operation)


class AgentMetadataService(AgentMetadataServiceServicer):
async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse:
return GetAgentResponse(agent=AgentRegistry._METADATA[request.name])
return GetAgentResponse(agent=AgentRegistry.get_agent_metadata(request.name))

async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse:
agents = [agent for agent in AgentRegistry._METADATA.values()]
return ListAgentsResponse(agents=agents)
return ListAgentsResponse(agents=AgentRegistry.list_agents())
Loading

0 comments on commit 3a28aff

Please sign in to comment.