diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index c1700c3f30..1cb3f0500d 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -281,7 +281,6 @@ jobs: - python-version: 3.11 plugin-names: "flytekit-whylogs" steps: - - uses: insightsengineering/disk-space-reclaimer@v1 - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -296,12 +295,12 @@ jobs: key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('plugins/{0}/requirements.txt', matrix.plugin-names ))) }} - name: Install dependencies run: | + export SETUPTOOLS_SCM_PRETEND_VERSION="2.0.0" make setup cd plugins/${{ matrix.plugin-names }} pip install . - if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi + if [ -f dev-requirements.in ]; then pip install -r dev-requirements.in; fi pip install -U $GITHUB_WORKSPACE - pip install --no-deps -U --force-reinstall "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" pip freeze - name: Test with coverage run: | diff --git a/Dockerfile.dev b/Dockerfile.dev index 2b85a5f7d2..63019e5e38 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -28,8 +28,8 @@ COPY . /flytekit # 3. Clean up the apt cache to reduce image size. Reference: https://gist.github.com/marvell/7c812736565928e602c4 # 4. Create a non-root user 'flytekit' and set appropriate permissions for directories. 
RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \ + && pip install "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \ && pip install --no-cache-dir -U --pre \ - flyteidl \ -e /flytekit \ -e /flytekit/plugins/flytekit-k8s-pod \ -e /flytekit/plugins/flytekit-deck-standard \ @@ -43,6 +43,7 @@ RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \ && chown flytekit: /home \ && : + ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" # Switch to the 'flytekit' user for better security. diff --git a/Makefile b/Makefile index 585b76a5c0..fa245479dd 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,6 @@ update_boilerplate: .PHONY: setup setup: install-piptools ## Install requirements pip install -r dev-requirements.in - pip install --no-deps -U --force-reinstall "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" .PHONY: fmt fmt: diff --git a/dev-requirements.in b/dev-requirements.in index d9784f75d0..d866cfc1c8 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,4 +1,5 @@ -e file:.#egg=flytekit +git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl coverage[toml] hypothesis diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 87f008b084..2783afc727 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -4,6 +4,7 @@ from flyteidl.service.agent_pb2_grpc import ( add_AgentMetadataServiceServicer_to_server, add_AsyncAgentServiceServicer_to_server, + add_SyncAgentServiceServicer_to_server, ) from grpc import aio @@ -52,7 +53,7 @@ def agent(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") - from flytekit.extend.backend.agent_service import AgentMetadataService, 
AsyncAgentService + from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService try: from prometheus_client import start_http_server @@ -64,6 +65,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) + add_SyncAgentServiceServicer_to_server(SyncAgentService(), server) add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server) server.add_insecure_port(f"[::]:{port}") diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 76f750233b..9153fca032 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -16,6 +16,7 @@ from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json +from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict @@ -1164,19 +1165,34 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. 
@classmethod @timeit("Translate literal to python value") def literal_map_to_kwargs( - cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type] + cls, + ctx: FlyteContext, + lm: LiteralMap, + python_types: typing.Optional[typing.Dict[str, type]] = None, + literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None, ) -> typing.Dict[str, typing.Any]: """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - if len(lm.literals) > len(python_types): + if python_types is None and literal_types is None: + raise ValueError("At least one of python_types or literal_types must be provided") + + if literal_types: + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items() + } + else: + python_interface_inputs = python_types # type: ignore + + if len(lm.literals) > len(python_interface_inputs): raise ValueError( - f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" + f"Received more input values {len(lm.literals)}" + f" than allowed by the input spec {len(python_interface_inputs)}" ) kwargs = {} for i, k in enumerate(lm.literals): try: - kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) + kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) except TypeTransformerFailedError as exc: raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc return kwargs @@ -1210,6 +1226,16 @@ def dict_to_literal_map( raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) return LiteralMap(literal_map) + @classmethod + def dict_to_literal_map_pb( + cls, + ctx: FlyteContext, + d: typing.Dict[str, typing.Any], + type_hints: Optional[typing.Dict[str, type]] = None, + ) -> Optional[literals_pb2.LiteralMap]: + literal_map = cls.dict_to_literal_map(ctx, d, type_hints) + 
return literal_map.to_flyte_idl() + @classmethod def get_available_transformers(cls) -> typing.KeysView[Type]: """ diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 2d4246c6c1..c000b92150 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -1,4 +1,5 @@ import typing +from http import HTTPStatus import grpc from flyteidl.admin.agent_pb2 import ( @@ -6,19 +7,28 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, + ExecuteTaskSyncRequest, + ExecuteTaskSyncResponse, + ExecuteTaskSyncResponseHeader, GetAgentRequest, GetAgentResponse, GetTaskRequest, GetTaskResponse, ListAgentsRequest, ListAgentsResponse, + Resource, +) +from flyteidl.service.agent_pb2_grpc import ( + AgentMetadataServiceServicer, + AsyncAgentServiceServicer, + SyncAgentServiceServicer, ) -from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer from prometheus_client import Counter, Summary -from flytekit import logger +from flytekit import FlyteContext, logger +from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.system import FlyteAgentNotFound -from flytekit.extend.backend.base_agent import AgentRegistry, mirror_async_methods +from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -26,6 +36,7 @@ create_operation = "create" get_operation = "get" delete_operation = "delete" +do_operation = "do" # Follow the naming convention. 
https://prometheus.io/docs/practices/naming/ request_success_count = Counter( @@ -46,7 +57,24 @@ input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"]) -def agent_exception_handler(func: typing.Callable): +def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str): + if isinstance(e, FlyteAgentNotFound): + error_message = f"Cannot find agent for task type: {task_type}." + logger.error(error_message) + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(error_message) + request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc() + else: + error_message = f"failed to {operation} {task_type} task with error: {e}." + logger.error(error_message) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(error_message) + request_failure_count.labels( + task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR + ).inc() + + +def record_agent_metrics(func: typing.Callable): async def wrapper( self, request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest], @@ -60,10 +88,10 @@ async def wrapper( if request.inputs: input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize()) elif isinstance(request, GetTaskRequest): - task_type = request.task_type + task_type = request.task_type or request.task_category.name operation = get_operation elif isinstance(request, DeleteTaskRequest): - task_type = request.task_type + task_type = request.task_type or request.task_category.name operation = delete_operation else: context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -75,51 +103,90 @@ async def wrapper( res = await func(self, request, context, *args, **kwargs) request_success_count.labels(task_type=task_type, operation=operation).inc() return res - except FlyteAgentNotFound: - error_message = f"Cannot find agent for task type: {task_type}." 
- logger.error(error_message) - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(error_message) - request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc() except Exception as e: - error_message = f"failed to {operation} {task_type} task with error {e}." - logger.error(error_message) - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(error_message) - request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc() + _handle_exception(e, context, task_type, operation) return wrapper class AsyncAgentService(AsyncAgentServiceServicer): - @agent_exception_handler + @record_agent_metrics async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: - tmp = TaskTemplate.from_flyte_idl(request.template) + template = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(tmp.type) + agent = AgentRegistry.get_agent(template.type, template.task_type_version) - logger.info(f"{tmp.type} agent start creating the job") - return await mirror_async_methods( - agent.create, output_prefix=request.output_prefix, task_template=tmp, inputs=inputs - ) + logger.info(f"{agent.name} start creating the job") + resource_mata = await mirror_async_methods(agent.create, task_template=template, inputs=inputs) + return CreateTaskResponse(resource_meta=resource_mata.encode()) - @agent_exception_handler + @record_agent_metrics async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: - agent = AgentRegistry.get_agent(request.task_type) - logger.info(f"{agent.task_type} agent start checking the status of the job") - return await mirror_async_methods(agent.get, resource_meta=request.resource_meta) + if request.task_category and request.task_category.name: + agent = 
AgentRegistry.get_agent(request.task_category.name, request.task_category.version) + else: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.name} start checking the status of the job") + res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) + + if res.outputs is None: + outputs = None + elif isinstance(res.outputs, LiteralMap): + outputs = res.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + return GetTaskResponse( + resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) + ) - @agent_exception_handler + @record_agent_metrics async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: - agent = AgentRegistry.get_agent(request.task_type) - logger.info(f"{agent.task_type} agent start deleting the job") - return await mirror_async_methods(agent.delete, resource_meta=request.resource_meta) + if request.task_category and request.task_category.name: + agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version) + else: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.name} start deleting the job") + return await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta)) + + +class SyncAgentService(SyncAgentServiceServicer): + async def ExecuteTaskSync( + self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext + ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: + request = await request_iterator.__anext__() + template = TaskTemplate.from_flyte_idl(request.header.template) + task_type = template.type + try: + with request_latency.labels(task_type=task_type, operation=do_operation).time(): + agent = AgentRegistry.get_agent(task_type, template.task_type_version) + if not 
isinstance(agent, SyncAgentBase): + raise ValueError(f"[{agent.name}] does not support sync execution") + + request = await request_iterator.__anext__() + literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + + if res.outputs is None: + outputs = None + elif isinstance(res.outputs, LiteralMap): + outputs = res.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + + header = ExecuteTaskSyncResponseHeader( + resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) + ) + yield ExecuteTaskSyncResponse(header=header) + request_success_count.labels(task_type=task_type, operation=do_operation).inc() + except Exception as e: + _handle_exception(e, context, template.type, do_operation) class AgentMetadataService(AgentMetadataServiceServicer): async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse: - return GetAgentResponse(agent=AgentRegistry._METADATA[request.name]) + return GetAgentResponse(agent=AgentRegistry.get_agent_metadata(request.name)) async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse: - agents = [agent for agent in AgentRegistry._METADATA.values()] - return ListAgentsResponse(agents=agents) + return ListAgentsResponse(agents=AgentRegistry.list_agents()) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 4648f5ecdf..8f4bd96e6e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,72 +1,161 @@ import asyncio -import inspect +import json import signal import sys import time import typing -from abc import ABC +from abc import ABC, abstractmethod from collections import OrderedDict +from dataclasses import asdict, dataclass from 
functools import partial from types import FrameType, coroutine +from typing import Any, Dict, List, Optional, Union -from flyteidl.admin.agent_pb2 import ( - Agent, - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, -) +from flyteidl.admin.agent_pb2 import Agent +from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 -from flyteidl.core.execution_pb2 import TaskExecution -from flyteidl.core.tasks_pb2 import TaskTemplate +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from rich.progress import Progress -import flytekit from flytekit import FlyteContext, PythonFunctionTask, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import utils from flytekit.core.base_task import PythonTask -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, dataclass_from_dict from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.exceptions.user import FlyteUserException +from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class TaskCategory: + def __init__(self, name: str, version: int = 0): + self._name = name + self._version = version + + def __hash__(self): + return hash((self._name, self._version)) + + def __eq__(self, other: "TaskCategory"): + return self._name == other._name and self._version == other._version + + @property + def name(self) -> str: + return self._name + + @property + def version(self) -> int: + return self._version + + def __str__(self): + return f"{self._name}_v{self._version}" + + +@dataclass +class ResourceMeta: + """ + This is the metadata for the job. For example, the id of the job. + """ + + def encode(self) -> bytes: + """ + Encode the resource meta to bytes. 
+ """ + return json.dumps(asdict(self)).encode("utf-8") + + @classmethod + def decode(cls, data: bytes) -> "ResourceMeta": + """ + Decode the resource meta from bytes. + """ + return dataclass_from_dict(cls, json.loads(data.decode("utf-8"))) + + +@dataclass +class Resource: + """ + This is the output resource of the job. + + Args: + phase: The phase of the job. + message: The return message from the job. + log_links: The log links of the job. For example, the link to the BigQuery Console. + outputs: The outputs of the job. If return python native types, the agent will convert them to flyte literals. + """ + + phase: TaskExecution.Phase + message: Optional[str] = None + log_links: Optional[List[TaskLog]] = None + outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None + + +T = typing.TypeVar("T", bound=ResourceMeta) class AgentBase(ABC): + name = "Base Agent" + + def __init__(self, task_type_name: str, task_type_version: int = 0, **kwargs): + self._task_category = TaskCategory(name=task_type_name, version=task_type_version) + + @property + def task_category(self) -> TaskCategory: + """ + task category that the agent supports + """ + return self._task_category + + +class SyncAgentBase(AgentBase): """ - This is the base class for all agents. It defines the interface that all agents must implement. - The agent service will be run either locally or in a pod, and will be responsible for - invoking agents. The propeller will communicate with the agent service + This is the base class for all sync agents. It defines the interface that all agents must implement. + The agent service is responsible for invoking agents. + Propeller sends a request to agent service, and gets a response in the same call. + + All the agents should be registered in the AgentRegistry. Agent Service + will look up the agent based on the task type. Every task type can only have one agent. 
+ """ + + name = "Base Sync Agent" + + @abstractmethod + def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource: + """ + This is the method that the agent will run. + """ + raise NotImplementedError + + +class AsyncAgentBase(AgentBase, typing.Generic[T]): + """ + This is the base class for all async agents. It defines the interface that all agents must implement. + The agent service is responsible for invoking agents. The propeller will communicate with the agent service to create tasks, get the status of tasks, and delete tasks. All the agents should be registered in the AgentRegistry. Agent Service will look up the agent based on the task type. Every task type can only have one agent. """ - name = "Base Agent" + name = "Base Async Agent" - def __init__(self, task_type: str, **kwargs): - self._task_type = task_type + def __init__(self, metadata_type: typing.Type[T], **kwargs): + super().__init__(**kwargs) + self._metadata_type = metadata_type @property - def task_type(self) -> str: - """ - task_type is the name of the task type that this agent supports. - """ - return self._task_type - - def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + def metadata_type(self) -> ResourceMeta: + return self._metadata_type + + @abstractmethod + def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> T: """ - Return a Unique ID for the task that was created. It should return error code if the task creation failed. + Return a resource meta that can be used to get the status of the task. """ raise NotImplementedError - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + @abstractmethod + def get(self, resource_meta: T, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. 
For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -74,9 +163,10 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: """ raise NotImplementedError - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + @abstractmethod + def delete(self, resource_meta: T, **kwargs): """ - Delete the task. This call should be idempotent. + Delete the task. This call should be idempotent. It should raise an error if fails to delete the task. """ raise NotImplementedError @@ -88,29 +178,42 @@ class AgentRegistry(object): The agent metadata service will look up the agent metadata based on the agent name. """ - _REGISTRY: typing.Dict[str, AgentBase] = {} - _METADATA: typing.Dict[str, Agent] = {} + _REGISTRY: Dict[TaskCategory, Union[AsyncAgentBase, SyncAgentBase]] = {} + _METADATA: Dict[str, Agent] = {} @staticmethod - def register(agent: AgentBase): - if agent.task_type in AgentRegistry._REGISTRY: - raise ValueError(f"Duplicate agent for task type {agent.task_type}") - AgentRegistry._REGISTRY[agent.task_type] = agent + def register(agent: Union[AsyncAgentBase, SyncAgentBase], override: bool = False): + if agent.task_category in AgentRegistry._REGISTRY and override is False: + raise ValueError(f"Duplicate agent for task type: {agent.task_category}") + AgentRegistry._REGISTRY[agent.task_category] = agent + + task_category = _TaskCategory(name=agent.task_category.name, version=agent.task_category.version) if agent.name in AgentRegistry._METADATA: - agent_metadata = AgentRegistry._METADATA[agent.name] - agent_metadata.supported_task_types.append(agent.task_type) + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + agent_metadata.supported_task_categories.append(task_category) + agent_metadata.supported_task_types.append(task_category.name) else: - agent_metadata = Agent(name=agent.name, supported_task_types=[agent.task_type]) + agent_metadata = Agent( + 
name=agent.name, + supported_task_types=[task_category.name], + supported_task_categories=[task_category], + is_sync=isinstance(agent, SyncAgentBase), + ) AgentRegistry._METADATA[agent.name] = agent_metadata - logger.info(f"Registering an agent for task type: {agent.task_type}, name: {agent.name}") + logger.info(f"Registering {agent.name} for task type: {agent.task_category}") @staticmethod - def get_agent(task_type: str) -> AgentBase: - if task_type not in AgentRegistry._REGISTRY: - raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.") - return AgentRegistry._REGISTRY[task_type] + def get_agent(task_type_name: str, task_type_version: int = 0) -> Union[SyncAgentBase, AsyncAgentBase]: + task_category = TaskCategory(name=task_type_name, version=task_type_version) + if task_category not in AgentRegistry._REGISTRY: + raise FlyteAgentNotFound(f"Cannot find agent for task category: {task_category}.") + return AgentRegistry._REGISTRY[task_category] + + @staticmethod + def list_agents() -> List[Agent]: + return list(AgentRegistry._METADATA.values()) @staticmethod def get_agent_metadata(name: str) -> Agent: @@ -119,89 +222,87 @@ def get_agent_metadata(name: str) -> Agent: return AgentRegistry._METADATA[name] -def mirror_async_methods(func: typing.Callable, **kwargs) -> typing.Coroutine: - if inspect.iscoroutinefunction(func): - return func(**kwargs) - args = [v for _, v in kwargs.items()] - return asyncio.get_running_loop().run_in_executor(None, func, *args) - - -def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: - """ - Convert the state from the agent to the phase in flyte. +class SyncAgentExecutorMixin: """ - state = state.lower() - # timedout is the state of Databricks job. 
https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate - if state in ["failed", "timeout", "timedout", "canceled"]: - return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: - return TaskExecution.SUCCEEDED - elif state in ["running"]: - return TaskExecution.RUNNING - raise ValueError(f"Unrecognized state: {state}") - - -def is_terminal_phase(phase: TaskExecution.Phase) -> bool: - """ - Return true if the phase is terminal. + This mixin class is used to run the sync task locally, and it's only used for local execution. + Task should inherit from this class if the task can be run in the agent. + + Synchronous tasks run quickly and can return their results instantly. + Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ - return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] + T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask) + + def execute(self: T, **kwargs) -> LiteralMap: + from flytekit.tools.translator import get_serializable + + ctx = FlyteContext.current_context() + ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) + task_template = get_serializable(OrderedDict(), ss, self).template + + agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) -def get_agent_secret(secret_key: str) -> str: - return flytekit.current_context().secrets.get(secret_key) + resource = asyncio.run(self._do(agent, task_template, kwargs)) + if resource.phase != TaskExecution.SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") + + if resource.outputs and not isinstance(resource.outputs, LiteralMap): + return TypeEngine.dict_to_literal_map(ctx, resource.outputs) + return resource.outputs + + async def _do(self: T, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None) -> Resource: + ctx = FlyteContext.current_context() + 
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) + return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) class AsyncAgentExecutorMixin: """ - This mixin class is used to run the agent task locally, and it's only used for local execution. + This mixin class is used to run the async task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. - It can handle asynchronous tasks and synchronous tasks. + Asynchronous tasks are tasks that take a long time to complete, such as running a query. - Synchronous tasks run quickly and can return their results instantly. Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ + T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask) + _clean_up_task: coroutine = None - _agent: AgentBase = None - _entity: PythonTask = None + _agent: AsyncAgentBase = None - def execute(self, **kwargs) -> typing.Any: + def execute(self: T, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) output_prefix = ctx.file_access.get_random_remote_directory() from flytekit.tools.translator import get_serializable - self._entity = typing.cast(PythonTask, self) - task_template = get_serializable(OrderedDict(), ss, self._entity).template - self._agent = AgentRegistry.get_agent(task_template.type) + task_template = get_serializable(OrderedDict(), ss, self).template + self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - res = asyncio.run(self._create(task_template, output_prefix, kwargs)) - res = asyncio.run(self._get(resource_meta=res.resource_meta)) + resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs)) + resource = asyncio.run(self._get(resource_meta=resource_mata)) - if res.resource.phase != TaskExecution.SUCCEEDED: - 
raise FlyteUserException(f"Failed to run the task {self._entity.name}") + if resource.phase != TaskExecution.SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") - # Read the literals from a remote file, if agent doesn't return the output literals. - if task_template.interface.outputs and len(res.resource.outputs.literals) == 0: + # Read the literals from a remote file if the agent doesn't return the output literals. + if task_template.interface.outputs and resource.outputs is None: local_outputs_file = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(f"{output_prefix}/output/outputs.pb", local_outputs_file) + ctx.file_access.get_data(f"{output_prefix}/outputs.pb", local_outputs_file) output_proto = utils.load_proto_from_file(literals_pb2.LiteralMap, local_outputs_file) return LiteralMap.from_flyte_idl(output_proto) - return LiteralMap.from_flyte_idl(res.resource.outputs) + if resource.outputs and not isinstance(resource.outputs, LiteralMap): + return TypeEngine.dict_to_literal_map(ctx, resource.outputs) + + return resource.outputs async def _create( - self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Dict[str, typing.Any] = None - ) -> CreateTaskResponse: + self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None + ) -> ResourceMeta: ctx = FlyteContext.current_context() - # Convert python inputs to literals - literals = inputs or {} - for k, v in inputs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) - literal_map = LiteralMap(literals) - + literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) if isinstance(self, PythonFunctionTask): # Write the inputs to a remote file, so that the remote task can read the inputs from this file. 
path = ctx.file_access.get_random_local_path() @@ -209,58 +310,47 @@ async def _create( ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") task_template = render_task_template(task_template, output_prefix) - res = await mirror_async_methods( + resource_meta = await mirror_async_methods( self._agent.create, - output_prefix=output_prefix, task_template=task_template, inputs=literal_map, ) - signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore - return res + signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore + return resource_meta - async def _get(self, resource_meta: bytes) -> GetTaskResponse: + async def _get(self: T, resource_meta: ResourceMeta) -> Resource: phase = TaskExecution.RUNNING progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) + task = progress.add_task(f"[cyan]Running Task {self.name}...", total=None) task_phase = progress.add_task("[cyan]Task phase: RUNNING, Phase message: ", total=None, visible=False) task_log_links = progress.add_task("[cyan]Log Links: ", total=None, visible=False) with progress: while not is_terminal_phase(phase): progress.start_task(task) time.sleep(1) - res = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) + resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) if self._clean_up_task: await self._clean_up_task sys.exit(1) - phase = res.resource.phase + phase = resource.phase progress.update( task_phase, - description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {res.resource.message}", + description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {resource.message}", visible=True, ) - log_links = "" - for link in res.resource.log_links: - log_links += f"{link.name}: {link.uri}\n" - if log_links: - progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True) 
+ if resource.log_links: + log_links = "" + for link in resource.log_links: + log_links += f"{link.name}: {link.uri}\n" + if log_links: + progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True) - return res + return resource - def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: + def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameType) -> Any: if self._clean_up_task is None: co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta) self._clean_up_task = asyncio.create_task(co) - - -def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: - args = tt.container.args - for i in range(len(args)): - tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb") - tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}/output") - tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output") - tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output") - tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint") - return tt diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py new file mode 100644 index 0000000000..b20c9fdf66 --- /dev/null +++ b/flytekit/extend/backend/utils.py @@ -0,0 +1,52 @@ +import asyncio +import inspect +from typing import Callable, Coroutine + +from flyteidl.core.execution_pb2 import TaskExecution + +import flytekit +from flytekit.models.task import TaskTemplate + + +def mirror_async_methods(func: Callable, **kwargs) -> Coroutine: + if inspect.iscoroutinefunction(func): + return func(**kwargs) + args = [v for _, v in kwargs.items()] + return asyncio.get_running_loop().run_in_executor(None, func, *args) + + +def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: + """ + Convert the state from the agent to the phase in flyte. 
+ """ + state = state.lower() + # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + if state in ["failed", "timeout", "timedout", "canceled"]: + return TaskExecution.FAILED + elif state in ["done", "succeeded", "success"]: + return TaskExecution.SUCCEEDED + elif state in ["running"]: + return TaskExecution.RUNNING + raise ValueError(f"Unrecognized state: {state}") + + +def is_terminal_phase(phase: TaskExecution.Phase) -> bool: + """ + Return true if the phase is terminal. + """ + return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] + + +def get_agent_secret(secret_key: str) -> str: + return flytekit.current_context().secrets.get(secret_key) + + +def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: + args = tt.container.args + for i in range(len(args)): + tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb") + tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}") + tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output") + tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output") + tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint") + return tt diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 0e40055ea5..3392f77009 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -1,26 +1,48 @@ import collections import inspect +import typing from abc import abstractmethod +from dataclasses import asdict, dataclass from typing import Any, Dict, Optional, TypeVar -import jsonpickle -from typing_extensions import get_type_hints +from typing_extensions import Protocol, get_type_hints, runtime_checkable from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask 
from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin, ResourceMeta -T = TypeVar("T") -SENSOR_MODULE = "sensor_module" -SENSOR_NAME = "sensor_name" -SENSOR_CONFIG_PKL = "sensor_config_pkl" -INPUTS = "inputs" + +@runtime_checkable +class SensorConfig(Protocol): + def to_dict(self) -> typing.Dict[str, Any]: + """ + Serialize the sensor config to a dictionary. + """ + raise NotImplementedError + + @classmethod + def from_dict(cls, d: typing.Dict[str, Any]) -> "SensorConfig": + """ + Deserialize the sensor config from a dictionary. + """ + raise NotImplementedError + + +@dataclass +class SensorMetadata(ResourceMeta): + sensor_module: str + sensor_name: str + sensor_config: Optional[dict] = None + inputs: Optional[dict] = None + + +T = TypeVar("T", bound=SensorConfig) class BaseSensor(AsyncAgentExecutorMixin, PythonTask): """ - Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some + Base class for all sensors. Sensors are tasks that are designed to run forever and periodically check for some condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the sensor agent, and not by the Flyte engine. 
""" @@ -57,10 +79,9 @@ async def poke(self, **kwargs) -> bool: raise NotImplementedError def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - cfg = { - SENSOR_MODULE: type(self).__module__, - SENSOR_NAME: type(self).__name__, - } - if self._sensor_config is not None: - cfg[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config) - return cfg + sensor_config = self._sensor_config.to_dict() if self._sensor_config else None + return asdict( + SensorMetadata( + sensor_module=type(self).__module__, sensor_name=type(self).__name__, sensor_config=sensor_config + ) + ) diff --git a/flytekit/sensor/file_sensor.py b/flytekit/sensor/file_sensor.py index 2fb3d64ec1..f894546927 100644 --- a/flytekit/sensor/file_sensor.py +++ b/flytekit/sensor/file_sensor.py @@ -1,14 +1,10 @@ -from typing import Optional, TypeVar - from flytekit import FlyteContextManager from flytekit.sensor.base_sensor import BaseSensor -T = TypeVar("T") - class FileSensor(BaseSensor): - def __init__(self, name: str, config: Optional[T] = None, **kwargs): - super().__init__(name=name, sensor_config=config, **kwargs) + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) async def poke(self, path: str) -> bool: file_access = FlyteContextManager.current_context().file_access diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 816360715a..ac718abe35 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -1,62 +1,49 @@ import importlib -import typing from typing import Optional -import cloudpickle -import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import 
AgentRegistry, AsyncAgentBase, Resource from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SensorMetadata -T = typing.TypeVar("T") - -class SensorEngine(AgentBase): +class SensorEngine(AsyncAgentBase): name = "Sensor" def __init__(self): - super().__init__(task_type="sensor") - - async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - python_interface_inputs = { - name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() - } - ctx = FlyteContextManager.current_context() + super().__init__(task_type_name="sensor", metadata_type=SensorMetadata) + + async def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwarg) -> SensorMetadata: + sensor_metadata = SensorMetadata(**task_template.custom) + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - task_template.custom[INPUTS] = native_inputs - return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) + sensor_metadata.inputs = native_inputs - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - meta = cloudpickle.loads(resource_meta) + return sensor_metadata - sensor_module = importlib.import_module(name=meta[SENSOR_MODULE]) - sensor_def = getattr(sensor_module, meta[SENSOR_NAME]) - sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None + async def get(self, resource_meta: SensorMetadata, **kwargs) -> Resource: + sensor_module = 
importlib.import_module(name=resource_meta.sensor_module) + sensor_def = getattr(sensor_module, resource_meta.sensor_name) - inputs = meta.get(INPUTS, {}) + inputs = resource_meta.inputs cur_phase = ( TaskExecution.SUCCEEDED - if await sensor_def("sensor", config=sensor_config).poke(**inputs) + if await sensor_def("sensor", config=resource_meta.sensor_config).poke(**inputs) else TaskExecution.RUNNING ) - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=None)) + return Resource(phase=cur_phase, outputs=None) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + async def delete(self, resource_meta: SensorMetadata, **kwargs): + return AgentRegistry.register(SensorEngine()) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 06b0c44a87..da6bc4d699 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -158,7 +158,7 @@ def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: return cls(path=remote_path) def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: - from . 
import FileExt + from flytekit.types.file import FileExt if item is None: return cls diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py index e52453d7bb..2ff0d0e9a5 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py @@ -5,12 +5,6 @@ import cloudpickle import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance @@ -21,13 +15,13 @@ from airflow.utils.context import Context from flytekit import logger from flytekit.exceptions.user import FlyteUserException -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class ResourceMetadata: +class AirflowMetadata(ResourceMeta): """ This class is used to store the Airflow task configuration. It is serialized and returned to FlytePropeller. """ @@ -37,8 +31,15 @@ class ResourceMetadata: airflow_trigger_callback: str = field(default=None) job_id: typing.Optional[str] = field(default=None) + def encode(self) -> bytes: + return cloudpickle.dumps(self) -class AirflowAgent(AgentBase): + @classmethod + def decode(cls, data: bytes) -> "AirflowMetadata": + return cloudpickle.loads(data) + + +class AirflowAgent(AsyncAgentBase): """ It is used to run Airflow tasks. It is registered as an agent in the AgentRegistry. There are three kinds of Airflow tasks: AirflowOperator, AirflowSensor, and AirflowHook. 
@@ -62,22 +63,18 @@ class AirflowAgent(AgentBase): name = "Airflow Agent" def __init__(self): - super().__init__(task_type="airflow") + super().__init__(task_type_name="airflow", metadata_type=AirflowMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> AirflowMetadata: airflow_obj = jsonpickle.decode(task_template.custom["task_config_pkl"]) airflow_instance = _get_airflow_instance(airflow_obj) - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) if isinstance(airflow_instance, BaseOperator) and not isinstance(airflow_instance, BaseSensorOperator): try: - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) airflow_instance.execute(context=Context()) except TaskDeferred as td: parameters = td.trigger.__dict__.copy() @@ -90,12 +87,13 @@ async def create( ) resource_meta.airflow_trigger_callback = td.method_name - return CreateTaskResponse(resource_meta=cloudpickle.dumps(resource_meta)) + return resource_meta - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - meta = cloudpickle.loads(resource_meta) - airflow_operator_instance = _get_airflow_instance(meta.airflow_operator) - airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None + async def get(self, resource_meta: AirflowMetadata, **kwargs) -> Resource: + airflow_operator_instance = _get_airflow_instance(resource_meta.airflow_operator) + airflow_trigger_instance = ( + _get_airflow_instance(resource_meta.airflow_trigger) if resource_meta.airflow_trigger else None + ) airflow_ctx = Context() message = None cur_phase = TaskExecution.RUNNING @@ -107,7 +105,7 @@ async def get(self, resource_meta: 
bytes, **kwargs) -> GetTaskResponse: if airflow_trigger_instance: try: # Airflow trigger returns immediately when - # 1. Failed to get the task status + # 1. Failed to get task status # 2. Task succeeded or failed # succeeded or failed: returns a TriggerEvent with payload # running: runs forever, so set a default timeout (2 seconds) here. @@ -115,7 +113,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: event = await asyncio.wait_for(airflow_trigger_instance.run().__anext__(), 2) try: # Trigger callback will check the status of the task in the payload, and raise AirflowException if failed. - trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback) + trigger_callback = getattr(airflow_operator_instance, resource_meta.airflow_trigger_callback) trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload) cur_phase = TaskExecution.SUCCEEDED except AirflowException as e: @@ -136,10 +134,10 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: else: raise FlyteUserException("Only sensor and operator are supported.") - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message)) + return Resource(phase=cur_phase, message=message) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + async def delete(self, resource_meta: AirflowMetadata, **kwargs): + return AgentRegistry.register(AirflowAgent()) diff --git a/plugins/flytekit-airflow/setup.py b/plugins/flytekit-airflow/setup.py index 682cd72c18..09536d2e90 100644 --- a/plugins/flytekit-airflow/setup.py +++ b/plugins/flytekit-airflow/setup.py @@ -6,8 +6,8 @@ plugin_requires = [ "apache-airflow", - "flytekit>=1.9.0", - "flyteidl>=1.10.6", + "flytekit>1.10.7", + "flyteidl>1.10.7", ] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index dc4d167b10..57999d5c59 100644 --- 
a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -5,10 +5,9 @@ from airflow.operators.python import PythonOperator from airflow.sensors.bash import BashSensor from airflow.sensors.time_sensor import TimeSensor -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow import AirflowObj -from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata +from flytekitplugins.airflow.agent import AirflowAgent, AirflowMetadata from flytekit import workflow from flytekit.interfaces.cli_identifiers import Identifier @@ -44,7 +43,7 @@ def test_resource_metadata(): parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, ) trigger_cfg = AirflowObj(module="airflow.trigger.file", name="FileTrigger", parameters={"filepath": "file.txt"}) - meta = ResourceMetadata( + meta = AirflowMetadata( airflow_operator=task_cfg, airflow_trigger=trigger_cfg, airflow_trigger_callback="execute_complete", @@ -89,10 +88,9 @@ async def test_airflow_agent(): ) agent = AirflowAgent() - res = await agent.create("/tmp", dummy_template, None) - metadata = res.resource_meta - res = await agent.get(metadata) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.message == "" + metadata = await agent.create(dummy_template, None) + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.message is None res = await agent.delete(metadata) - assert res == DeleteTaskResponse() + assert res is None diff --git a/plugins/flytekit-aws-batch/setup.py b/plugins/flytekit-aws-batch/setup.py index db75ce18b9..423e439ba2 100644 --- a/plugins/flytekit-aws-batch/setup.py +++ b/plugins/flytekit-aws-batch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.3.0b2"] __version__ = "0.0.0+develop" diff 
--git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 0418d4f809..0275162f72 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -1,25 +1,16 @@ import datetime -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Dict, Optional -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase -from flytekit.models import literals -from flytekit.models.core.execution import TaskLog +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.models.types import LiteralType, StructuredDatasetType pythonTypeToBigQueryType: Dict[type, str] = { # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes @@ -34,25 +25,24 @@ @dataclass -class Metadata: +class BigQueryMetadata(ResourceMeta): job_id: str project: str location: str -class BigQueryAgent(AgentBase): +class BigQueryAgent(AsyncAgentBase[BigQueryMetadata]): name = "Bigquery Agent" def __init__(self): - super().__init__(task_type="bigquery_query_job_task") + super().__init__(task_type_name="bigquery_query_job_task", metadata_type=BigQueryMetadata) def create( self, - output_prefix: str, task_template: TaskTemplate, inputs: 
Optional[LiteralMap] = None, **kwargs, - ) -> CreateTaskResponse: + ) -> BigQueryMetadata: job_config = None if inputs: ctx = FlyteContextManager.current_context() @@ -73,54 +63,36 @@ def create( location = custom["Location"] client = bigquery.Client(project=project, location=location) query_job = client.query(task_template.sql.statement, job_config=job_config) - metadata = Metadata(job_id=str(query_job.job_id), location=location, project=project) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - log_links = [ - TaskLog( - uri=f"https://console.cloud.google.com/bigquery?project={metadata.project}&j=bq:{metadata.location}:{metadata.job_id}&page=queryresults", - name="BigQuery Console", - ).to_flyte_idl() - ] - - job = client.get_job(metadata.job_id, metadata.project, metadata.location) + log_link = TaskLog( + uri=f"https://console.cloud.google.com/bigquery?project={resource_meta.project}&j=bq:{resource_meta.location}:{resource_meta.job_id}&page=queryresults", + name="BigQuery Console", + ) + + job = client.get_job(resource_meta.job_id, resource_meta.project, resource_meta.location) if job.errors: logger.error("failed to run BigQuery job with error:", job.errors.__str__()) - return GetTaskResponse( - resource=Resource(state=TaskExecution.FAILED, message=job.errors.__str__()), log_links=log_links - ) + return Resource(phase=TaskExecution.FAILED, message=job.errors.__str__(), log_links=[log_link]) cur_phase = convert_to_flyte_phase(str(job.state)) res = None if cur_phase == TaskExecution.SUCCEEDED: - ctx = FlyteContextManager.current_context() - if job.destination: - output_location = ( - 
f"bq://{job.destination.project}:{job.destination.dataset_id}.{job.destination.table_id}" - ) - res = literals.LiteralMap( - { - "results": TypeEngine.to_literal( - ctx, - StructuredDataset(uri=output_location), - StructuredDataset, - LiteralType(structured_dataset_type=StructuredDatasetType(format="")), - ) - } - ).to_flyte_idl() - - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res, log_links=log_links)) - - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + dst = job.destination + if dst: + ctx = FlyteContextManager.current_context() + output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" + res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)}) + + return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res) + + def delete(self, resource_meta: BigQueryMetadata, **kwargs): client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - client.cancel_job(metadata.job_id, metadata.project, metadata.location) - return DeleteTaskResponse() + client.cancel_job(resource_meta.job_id, resource_meta.project, resource_meta.location) AgentRegistry.register(BigQueryAgent()) diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index 10dd3c7ca5..9f2dea65c0 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "google-cloud-bigquery", "flyteidl>=v1.10.6"] +plugin_requires = ["flytekit>1.10.7", "google-cloud-bigquery", "flyteidl>1.10.7"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index 0293ebd3cc..5897b4b468 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -1,10 +1,8 @@ -import json -from 
dataclasses import asdict from datetime import timedelta from unittest import mock from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.bigquery.agent import Metadata +from flytekitplugins.bigquery.agent import BigQueryMetadata import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry @@ -86,20 +84,18 @@ def __init__(self): sql=Sql("SELECT 1"), ) - metadata_bytes = json.dumps( - asdict(Metadata(job_id="dummy_id", project="dummy_project", location="us-central1")) - ).encode("utf-8") - assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - res = agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + metadata = BigQueryMetadata(job_id="dummy_id", project="dummy_project", location="us-central1") + assert agent.create(dummy_template, task_inputs) == metadata + resource = agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" ) - assert res.resource.log_links[0].name == "BigQuery Console" + assert resource.log_links[0].name == "BigQuery Console" assert ( - res.resource.log_links[0].uri + resource.log_links[0].uri == "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults" ) - agent.delete(metadata_bytes) + agent.delete(metadata) mock_instance.cancel_job.assert_called() diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index 0ef3fcf2fc..506dd4853b 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.5.0,<2.0.0", + "flytekit>=1.5.0", 
"great-expectations>=0.13.30,<=0.18.8", "sqlalchemy>=1.4.23,<2.0.0", "pyspark==3.3.1", diff --git a/plugins/flytekit-k8s-pod/setup.py b/plugins/flytekit-k8s-pod/setup.py index 9767c24ddb..1a3479805b 100644 --- a/plugins/flytekit-k8s-pod/setup.py +++ b/plugins/flytekit-k8s-pod/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.3.0b2,<2.0.0", + "flytekit>=1.3.0b2", "kubernetes>=12.0.1", ] diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py index 285be4e88b..e0dbceada2 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py @@ -1,30 +1,29 @@ import json import shlex import subprocess -from dataclasses import asdict, dataclass +from dataclasses import dataclass from tempfile import NamedTemporaryFile from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_phase from flytekit import current_context -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.loggers import logger from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class Metadata: +class MMCloudMetadata(ResourceMeta): job_id: str -class MMCloudAgent(AgentBase): +class MMCloudAgent(AsyncAgentBase): name = "MMCloud Agent" def __init__(self): - super().__init__(task_type="mmcloud_task", asynchronous=True) + super().__init__(task_type_name="mmcloud_task", metadata_type=MMCloudMetadata) self._response_format = ["--format", "json"] async def async_login(self): @@ -57,10 +56,10 @@ async def async_login(self): logger.info("Logged in to OpCenter") async def 
create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> MMCloudMetadata: """ - Submit Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. + Submit a Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. """ submit_command = [ "float", @@ -128,16 +127,13 @@ async def create( logger.exception("Cannot open job script for writing") raise - metadata = Metadata(job_id=job_id) + return MMCloudMetadata(job_id=job_id) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + async def get(self, resource_meta: MMCloudMetadata, **kwargs) -> Resource: """ Return the status of the task, and return the outputs on success. """ - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id show_command = [ "float", @@ -173,14 +169,13 @@ async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}") logger.debug(f"OpCenter response: {show_response}") - return GetTaskResponse(resource=Resource(phase=task_phase)) + return Resource(phase=task_phase) - async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + async def delete(self, resource_meta: MMCloudMetadata, **kwargs): """ Delete the task. This call should be idempotent. 
""" - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id cancel_command = [ "float", @@ -203,7 +198,5 @@ async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskRespon logger.info(f"Submitted cancel request for MMCloud job: {job_id}") - return DeleteTaskResponse() - AgentRegistry.register(MMCloudAgent()) diff --git a/plugins/flytekit-mmcloud/tests/test_mmcloud.py b/plugins/flytekit-mmcloud/tests/test_mmcloud.py index eff4c4e63c..79830e2c56 100644 --- a/plugins/flytekit-mmcloud/tests/test_mmcloud.py +++ b/plugins/flytekit-mmcloud/tests/test_mmcloud.py @@ -115,7 +115,7 @@ def say_hello0(name: str) -> str: assert isinstance(agent, MMCloudAgent) create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -124,13 +124,13 @@ def say_hello0(name: str) -> str: ) resource_meta = create_task_response.resource_meta - get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + get_task_response = asyncio.run(agent.get(context=context, resource_meta=resource_meta)) phase = get_task_response.resource.phase assert phase in (TaskExecution.RUNNING, TaskExecution.SUCCEEDED) - asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + asyncio.run(agent.delete(context=context, resource_meta=resource_meta)) - get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + get_task_response = asyncio.run(agent.get(context=context, resource_meta=resource_meta)) phase = get_task_response.resource.phase assert phase == TaskExecution.FAILED @@ -146,7 +146,7 @@ def say_hello1(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello1) with pytest.raises(subprocess.CalledProcessError): create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, 
output_prefix="", task_template=task_spec.template, @@ -165,7 +165,7 @@ def say_hello2(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello2) with pytest.raises(subprocess.CalledProcessError): create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -183,7 +183,7 @@ def say_hello3(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello3) create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -191,7 +191,7 @@ def say_hello3(name: str) -> str: ) ) resource_meta = create_task_response.resource_meta - asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + asyncio.run(agent.delete(context=context, resource_meta=resource_meta)) @task( task_config=MMCloudConfig(), @@ -203,7 +203,7 @@ def say_hello4(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello4) create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -211,7 +211,7 @@ def say_hello4(name: str) -> str: ) ) resource_meta = create_task_response.resource_meta - asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + asyncio.run(agent.delete(context=context, resource_meta=resource_meta)) @task( task_config=MMCloudConfig(), @@ -222,7 +222,7 @@ def say_hello5(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello5) create_task_response = asyncio.run( - agent.async_create( + agent.create( context=context, output_prefix="", task_template=task_spec.template, @@ -230,4 +230,4 @@ def say_hello5(name: str) -> str: ) ) resource_meta = create_task_response.resource_meta - asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + 
asyncio.run(agent.delete(context=context, resource_meta=resource_meta)) diff --git a/plugins/flytekit-modin/setup.py b/plugins/flytekit-modin/setup.py index 7d62ea16fe..0a3394a2d0 100644 --- a/plugins/flytekit-modin/setup.py +++ b/plugins/flytekit-modin/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit<1.3.0b2,<2.0.0", + "flytekit", "modin[ray]>=0.13.0", "fsspec", ] diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py index 45780ae174..fe55536066 100644 --- a/plugins/flytekit-onnx-scikitlearn/setup.py +++ b/plugins/flytekit-onnx-scikitlearn/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit<1.3.0b2,<2.0.0", "skl2onnx>=1.10.3", "networkx<3.2; python_version<'3.9'"] +plugin_requires = ["flytekit", "skl2onnx>=1.10.3", "networkx<3.2; python_version<'3.9'"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index d0a9617bdb..3dc10d1afc 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,4 +1,4 @@ -flyteidl>=1.10.7b0 +-e file:../../.#egg=flytekit -e file:../../.#egg=flytekitplugins-pod&subdirectory=plugins/flytekit-k8s-pod -e file:../../.#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark -e file:../../.#egg=flytekitplugins-awsbatch&subdirectory=plugins/flytekit-aws-batch diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index d06bc68085..8cb38662e3 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -1,18 +1,12 @@ -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import ( - 
CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -25,7 +19,7 @@ @dataclass -class Metadata: +class SnowflakeJobMetadata(ResourceMeta): user: str account: str database: str @@ -53,7 +47,7 @@ def get_private_key(): return pkb -def get_connection(metadata: Metadata) -> snowflake_connector: +def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector: return snowflake_connector.connect( user=metadata.user, account=metadata.account, @@ -64,25 +58,18 @@ def get_connection(metadata: Metadata) -> snowflake_connector: ) -class SnowflakeAgent(AgentBase): +class SnowflakeAgent(AsyncAgentBase): def __init__(self): - super().__init__(task_type=TASK_TYPE) + super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata) async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - params = None - if inputs: - ctx = FlyteContextManager.current_context() - python_interface_inputs = { - name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() - } - native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - logger.info(f"Create Snowflake agent params with inputs: {native_inputs}") - params = native_inputs + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> 
SnowflakeJobMetadata: + ctx = FlyteContextManager.current_context() + literal_types = task_template.interface.inputs + params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs else None config = task_template.config - conn = snowflake_connector.connect( user=config["user"], account=config["account"], @@ -95,7 +82,7 @@ async def create( cs = conn.cursor() cs.execute_async(task_template.sql.statement, params=params) - metadata = Metadata( + return SnowflakeJobMetadata( user=config["user"], account=config["account"], database=config["database"], @@ -105,22 +92,19 @@ async def create( query_id=str(cs.sfqid), ) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource: + conn = get_connection(resource_meta) try: - query_status = conn.get_query_status_throw_if_error(metadata.query_id) + query_status = conn.get_query_status_throw_if_error(resource_meta.query_id) except snowflake_connector.ProgrammingError as err: logger.error("Failed to get snowflake job status with error:", err.msg) - return GetTaskResponse(resource=Resource(state=TaskExecution.FAILED)) + return Resource(phase=TaskExecution.FAILED) cur_phase = convert_to_flyte_phase(str(query_status.name)) res = None if cur_phase == TaskExecution.SUCCEEDED: ctx = FlyteContextManager.current_context() - output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}" + output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}" res = literals.LiteralMap( { "results": TypeEngine.to_literal( @@ -132,19 +116,17 @@ 
async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res)) + return Resource(phase=cur_phase, outputs=res) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs): + conn = get_connection(resource_meta) cs = conn.cursor() try: - cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{resource_meta.query_id}')") cs.fetchall() finally: cs.close() conn.close() - return DeleteTaskResponse() AgentRegistry.register(SnowflakeAgent()) diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index 527daa2486..b5265c299e 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"] +plugin_requires = ["flytekit>1.10.7", "snowflake-connector-python>=3.1.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 017297704e..f3dcb0686d 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -1,13 +1,10 @@ -import json -from dataclasses import asdict from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import pytest -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.snowflake.agent import Metadata +from flytekitplugins.snowflake.agent import SnowflakeJobMetadata import flytekit.models.interface as interface_models from flytekit import lazy_module @@ -30,8 
+27,11 @@ async def test_snowflake_agent(mock_get_private_key): mock_conn_instance = snowflake_connector.connect.return_value mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock - agent = AgentRegistry.get_agent("snowflake") + mock_cursor = MagicMock() + mock_cursor.sfqid = "dummy_id" + mock_conn_instance.cursor.return_value = mock_cursor + agent = AgentRegistry.get_agent("snowflake") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" ) @@ -82,32 +82,28 @@ async def test_snowflake_agent(mock_get_private_key): sql=Sql("SELECT 1"), ) - metadata = Metadata( + snowflake_metadata = SnowflakeJobMetadata( user="dummy_user", account="dummy_account", table="dummy_table", database="dummy_database", schema="dummy_schema", warehouse="dummy_warehouse", - query_id="dummy_query_id", + query_id="dummy_id", ) - res = await agent.create("/tmp", dummy_template, task_inputs) - metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id - metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") - assert res.resource_meta == metadata_bytes + metadata = await agent.create(dummy_template, task_inputs) + assert metadata == snowflake_metadata - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" ) - delete_response = await agent.delete(metadata_bytes) - - # Assert the response - assert isinstance(delete_response, DeleteTaskResponse) + delete_response = await agent.delete(snowflake_metadata) + assert delete_response is None # Verify that the expected methods were called on the mock 
cursor mock_cursor = mock_conn_instance.cursor.return_value diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index c75617b0e0..8200263ac3 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,15 +1,14 @@ import http import json -import pickle import typing from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase, get_agent_secret +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -20,24 +19,20 @@ @dataclass -class Metadata: +class DatabricksJobMetadata(ResourceMeta): databricks_instance: str run_id: str -class DatabricksAgent(AgentBase): +class DatabricksAgent(AsyncAgentBase): name = "Databricks Agent" def __init__(self): - super().__init__(task_type="spark", asynchronous=True) + super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> DatabricksJobMetadata: custom = task_template.custom container = task_template.container databricks_job = custom["databricksConf"] @@ -72,21 +67,18 @@ async def create( if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to create databricks job 
with error: {response}") - metadata = Metadata( - databricks_instance=databricks_instance, - run_id=str(response["run_id"]), - ) - return CreateTaskResponse(resource_meta=pickle.dumps(metadata)) + return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = pickle.loads(resource_meta) - databricks_instance = metadata.databricks_instance - databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={metadata.run_id}" + async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: + databricks_instance = resource_meta.databricks_instance + databricks_url = ( + f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" + ) async with aiohttp.ClientSession() as session: async with session.get(databricks_url, headers=get_header()) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() cur_phase = TaskExecution.RUNNING @@ -99,25 +91,21 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: message = state["state_message"] job_id = response.get("job_id") - databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" + databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}" log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message, log_links=log_links)) + return Resource(phase=cur_phase, message=message, log_links=log_links) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = pickle.loads(resource_meta) - - 
databricks_url = f"https://{metadata.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" - data = json.dumps({"run_id": metadata.run_id}) + async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): + databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" + data = json.dumps({"run_id": resource_meta.run_id}) async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") await resp.json() - return DeleteTaskResponse() - def get_header() -> typing.Dict[str, str]: token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index ac7b650ecb..4bc8983289 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.10.0", "pandas"] +plugin_requires = ["flytekit>1.10.7", "pyspark>=3.0.0", "aiohttp", "flyteidl>1.10.7", "pandas"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index f875800268..80f91c5c76 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -1,12 +1,11 @@ import http -import pickle from datetime import timedelta from unittest import mock import pytest from aioresponses import aioresponses from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, Metadata, get_header +from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, 
DatabricksJobMetadata, get_header from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier @@ -103,11 +102,9 @@ async def test_databricks_agent(): mocked_context = mock.patch("flytekit.current_context", autospec=True).start() mocked_context.return_value.secrets.get.return_value = mocked_token - metadata_bytes = pickle.dumps( - Metadata( - databricks_instance="test-account.cloud.databricks.com", - run_id="123", - ) + databricks_metadata = DatabricksJobMetadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="123", ) mock_create_response = {"run_id": "123"} @@ -118,17 +115,19 @@ async def test_databricks_agent(): delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel" with aioresponses() as mocked: mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response) - res = await agent.create("/tmp", dummy_template, None) - assert res.resource_meta == metadata_bytes + res = await agent.create(dummy_template, None) + assert res == databricks_metadata mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() - assert res.resource.message == "OK" + resource = await agent.get(databricks_metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.outputs is None + assert resource.message == "OK" + assert resource.log_links[0].name == "Databricks Console" + assert resource.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123" mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) - await agent.delete(metadata_bytes) + await agent.delete(databricks_metadata) assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} diff --git a/pyproject.toml 
b/pyproject.toml index c4b6c03e97..07d75cf00d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", + "flyteidl>1.10.7", "flyteidl>=1.11.0b0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f92489e9c4..a1c137dc48 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -26,9 +26,10 @@ @pytest.fixture(scope="session") def register(): - subprocess.run( + out = subprocess.run( [ "pyflyte", + "--verbose", "-c", CONFIG, "register", @@ -43,6 +44,7 @@ def register(): MODULE_PATH, ] ) + assert out.returncode == 0 def test_fetch_execute_launch_plan(register): @@ -52,7 +54,7 @@ def test_fetch_execute_launch_plan(register): assert execution.outputs["o0"] == "hello world" -def fetch_execute_launch_plan_with_args(register): +def test_fetch_execute_launch_plan_with_args(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.basic_workflow.my_wf", version=VERSION) execution = remote.execute(flyte_launch_plan, inputs={"a": 10, "b": "foobar"}, wait=True) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index ce9b9e5b9b..85c88def45 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -1,92 +1,108 @@ -import asyncio -import json import typing from collections import OrderedDict -from dataclasses import asdict, dataclass +from dataclasses import dataclass from unittest.mock import MagicMock, patch import grpc import pytest from flyteidl.admin.agent_pb2 import ( + CreateRequestHeader, CreateTaskRequest, - CreateTaskResponse, DeleteTaskRequest, - DeleteTaskResponse, + ExecuteTaskSyncRequest, + GetAgentRequest, GetTaskRequest, - 
GetTaskResponse, - Resource, + ListAgentsRequest, + ListAgentsResponse, + TaskCategory, ) -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings -from flytekit.extend.backend.agent_service import AsyncAgentService +from flytekit.core.base_task import PythonTask, kwtypes +from flytekit.core.interface import Interface +from flytekit.exceptions.system import FlyteAgentNotFound +from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( - AgentBase, AgentRegistry, + AsyncAgentBase, AsyncAgentExecutorMixin, - convert_to_flyte_phase, - get_agent_secret, + Resource, + ResourceMeta, + SyncAgentBase, + SyncAgentExecutorMixin, is_terminal_phase, render_task_template, ) +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models import literals -from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.tools.translator import get_serializable dummy_id = "dummy_id" -loop = asyncio.get_event_loop() @dataclass -class Metadata: +class DummyMetadata(ResourceMeta): job_id: str -class DummyAgent(AgentBase): +class DummyAgent(AsyncAgentBase): name = "Dummy Agent" def __init__(self): - super().__init__(task_type="dummy", asynchronous=False) + super().__init__(task_type_name="dummy", metadata_type=DummyMetadata) - def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + def create(self, task_template: TaskTemplate, inputs: 
typing.Optional[LiteralMap], **kwargs) -> DummyMetadata: + return DummyMetadata(job_id=dummy_id) - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - return GetTaskResponse( - resource=Resource( - phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()] - ), - ) + def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + def delete(self, resource_meta: DummyMetadata, **kwargs): + ... -class AsyncDummyAgent(AgentBase): +class AsyncDummyAgent(AsyncAgentBase): name = "Async Dummy Agent" def __init__(self): - super().__init__(task_type="async_dummy") + super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs + ) -> DummyMetadata: + return DummyMetadata(job_id=dummy_id) + + async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + + async def delete(self, resource_meta: DummyMetadata, **kwargs): + ... 
+ + +class MockOpenAIAgent(SyncAgentBase): + name = "mock openAI Agent" + + def __init__(self): + super().__init__(task_type_name="openai") - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - return GetTaskResponse(resource=Resource(phase=TaskExecution.SUCCEEDED)) + def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + +class MockAsyncOpenAIAgent(SyncAgentBase): + name = "mock async openAI Agent" + + def __init__(self): + super().__init__(task_type_name="async_openai") + + async def do(self, task_template: TaskTemplate, inputs: LiteralMap = None, **kwargs) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) def get_task_template(task_type: str) -> TaskTemplate: @@ -115,85 +131,149 @@ def simple_task(i: int): ) -dummy_template = get_task_template("dummy") -async_dummy_template = get_task_template("async_dummy") -sync_dummy_template = get_task_template("sync_dummy") - - def test_dummy_agent(): - AgentRegistry.register(DummyAgent()) + AgentRegistry.register(DummyAgent(), override=True) agent = AgentRegistry.get_agent("dummy") - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - res = agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert agent.delete(metadata_bytes) == DeleteTaskResponse() + template = get_task_template("dummy") + metadata = DummyMetadata(job_id=dummy_id) + assert agent.create(template, task_inputs) == DummyMetadata(job_id=dummy_id) + resource = agent.get(metadata) + assert resource.phase == 
TaskExecution.SUCCEEDED + assert resource.log_links[0].name == "console" + assert resource.log_links[0].uri == "localhost:3000" + assert agent.delete(metadata) is None class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): - super().__init__( - task_type="dummy", - **kwargs, - ) + super().__init__(task_type="dummy", **kwargs) t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") t.execute() t._task_type = "non-exist-type" - with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."): + with pytest.raises(Exception, match="Cannot find agent for task category: non-exist-type."): t.execute() - agent_metadata = AgentRegistry.get_agent_metadata("Dummy Agent") - assert agent_metadata.name == "Dummy Agent" - assert agent_metadata.supported_task_types == ["dummy"] - +@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"]) @pytest.mark.asyncio -async def test_async_dummy_agent(): - AgentRegistry.register(AsyncDummyAgent()) - agent = AgentRegistry.get_agent("async_dummy") - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await agent.create("/tmp", async_dummy_template, task_inputs) +async def test_async_agent_service(agent): + AgentRegistry.register(agent, override=True) + service = AsyncAgentService() + ctx = MagicMock(spec=grpc.ServicerContext) + + inputs_proto = task_inputs.to_flyte_idl() + output_prefix = "/tmp" + metadata_bytes = DummyMetadata(job_id=dummy_id).encode() + + tmp = get_task_template(agent.task_category.name).to_flyte_idl() + task_category = TaskCategory(name=agent.task_category.name, version=0) + req = CreateTaskRequest(inputs=inputs_proto, output_prefix=output_prefix, template=tmp) + + res = await service.CreateTask(req, ctx) assert res.resource_meta == metadata_bytes - res = await agent.get(metadata_bytes) + res = await service.GetTask(GetTaskRequest(task_category=task_category, 
resource_meta=metadata_bytes), ctx) assert res.resource.phase == TaskExecution.SUCCEEDED - res = await agent.delete(metadata_bytes) - assert res == DeleteTaskResponse() + res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) + assert res is None - agent_metadata = AgentRegistry.get_agent_metadata("Async Dummy Agent") - assert agent_metadata.name == "Async Dummy Agent" - assert agent_metadata.supported_task_types == ["async_dummy"] + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + assert agent_metadata.supported_task_types[0] == agent.task_category.name + assert agent_metadata.supported_task_categories[0].name == agent.task_category.name + + with pytest.raises(FlyteAgentNotFound): + AgentRegistry.get_agent_metadata("non-exist-namr") + + +def test_register_agent(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + assert AgentRegistry.get_agent("dummy").name == agent.name + + with pytest.raises(ValueError, match="Duplicate agent for task type: dummy_v0"): + AgentRegistry.register(agent) + + with pytest.raises(FlyteAgentNotFound): + AgentRegistry.get_agent("non-exist-type") + + agents = AgentRegistry.list_agents() + assert len(agents) >= 1 @pytest.mark.asyncio -async def run_agent_server(): - service = AsyncAgentService() +async def test_agent_metadata_service(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + ctx = MagicMock(spec=grpc.ServicerContext) - request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() - ) - async_request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() - ) - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + metadata_service = AgentMetadataService() + res = await metadata_service.ListAgents(ListAgentsRequest(), ctx) + assert isinstance(res, 
ListAgentsResponse) + res = await metadata_service.GetAgent(GetAgentRequest(name="Dummy Agent"), ctx) + assert res.agent.name == agent.name + assert res.agent.supported_task_types[0] == agent.task_category.name + assert res.agent.supported_task_categories[0].name == agent.task_category.name - res = await service.CreateTask(request, ctx) - assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - assert isinstance(res, DeleteTaskResponse) - res = await service.CreateTask(async_request, ctx) - assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) - assert isinstance(res, DeleteTaskResponse) +def test_openai_agent(): + AgentRegistry.register(MockOpenAIAgent(), override=True) + + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + super().__init__( + task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs + ) + + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +def test_async_openai_agent(): + AgentRegistry.register(MockAsyncOpenAIAgent(), override=True) + + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + super().__init__( + task_type="async_openai", + interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), + **kwargs, + ) + + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +async def get_request_iterator(task_type: str): + inputs_proto = 
task_inputs.to_flyte_idl()
+    template = get_task_template(task_type).to_flyte_idl()
+    header = CreateRequestHeader(template=template, output_prefix="/tmp")
+    yield ExecuteTaskSyncRequest(header=header)
+    yield ExecuteTaskSyncRequest(inputs=inputs_proto)
+
+@pytest.mark.asyncio
+async def test_sync_agent_service():
+    AgentRegistry.register(MockOpenAIAgent(), override=True)
+    ctx = MagicMock(spec=grpc.ServicerContext)
+
+    service = SyncAgentService()
+    res = await service.ExecuteTaskSync(get_request_iterator("openai"), ctx).__anext__()
+    assert res.header.resource.phase == TaskExecution.SUCCEEDED
+    assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1
+
+
+@pytest.mark.asyncio
+async def test_sync_agent_service_with_asyncio():
+    AgentRegistry.register(MockAsyncOpenAIAgent(), override=True)
+    AgentRegistry.register(DummyAgent(), override=True)
+    ctx = MagicMock(spec=grpc.ServicerContext)
 
-def test_agent_server():
-    loop.run_in_executor(None, run_agent_server)
+    service = SyncAgentService()
+    res = await service.ExecuteTaskSync(get_request_iterator("async_openai"), ctx).__anext__()
+    assert res.header.resource.phase == TaskExecution.SUCCEEDED
+    assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1
 
 
 def test_is_terminal_phase():
@@ -227,7 +307,8 @@ def test_get_agent_secret(mocked_context):
 
 
 def test_render_task_template():
-    tt = render_task_template(dummy_template, "s3://becket")
+    template = get_task_template("dummy")
+    tt = render_task_template(template, "s3://becket")
     assert tt.container.args == [
         "pyflyte-fast-execute",
         "--additional-distribution",
@@ -239,7 +320,7 @@ def test_render_task_template():
         "--inputs",
         "s3://becket/inputs.pb",
         "--output-prefix",
-        "s3://becket/output",
+        "s3://becket",
         "--raw-output-data-prefix",
         "s3://becket/raw_output",
         "--checkpoint-path",
diff --git a/tests/flytekit/unit/sensor/test_file_sensor.py b/tests/flytekit/unit/sensor/test_file_sensor.py
index f6a50836be..bb0553dc27 100644
--- a/tests/flytekit/unit/sensor/test_file_sensor.py
+++ b/tests/flytekit/unit/sensor/test_file_sensor.py
@@ -16,7 +16,12 @@ def test_sensor_task():
         env={"FOO": "baz"},
         image_config=ImageConfig(default_image=default_img, images=[default_img]),
     )
-    assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"}
+    assert sensor.get_custom(settings) == {
+        "sensor_module": "flytekit.sensor.file_sensor",
+        "sensor_name": "FileSensor",
+        "sensor_config": None,
+        "inputs": None,
+    }
     tmp_file = tempfile.NamedTemporaryFile()
 
     @task()
diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py
index b5353b61b4..4a12aed877 100644
--- a/tests/flytekit/unit/sensor/test_sensor_engine.py
+++ b/tests/flytekit/unit/sensor/test_sensor_engine.py
@@ -1,20 +1,20 @@
 import tempfile
+from dataclasses import asdict
 
-import cloudpickle
 import pytest
-from flyteidl.admin.agent_pb2 import DeleteTaskResponse
 from flyteidl.core.execution_pb2 import TaskExecution
 
 import flytekit.models.interface as interface_models
 from flytekit.extend.backend.base_agent import AgentRegistry
 from flytekit.models import literals, types
 from flytekit.sensor import FileSensor
-from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME
+from flytekit.sensor.base_sensor import SensorMetadata
 from tests.flytekit.unit.extend.test_agent import get_task_template
 
 
 @pytest.mark.asyncio
 async def test_sensor_engine():
+    file = tempfile.NamedTemporaryFile()
     interfaces = interface_models.TypedInterface(
         {
             "path": interface_models.Variable(types.LiteralType(types.SimpleType.STRING), "description1"),
@@ -22,12 +22,10 @@ async def test_sensor_engine():
         {},
     )
     tmp = get_task_template("sensor")
-    tmp._custom = {
-        SENSOR_MODULE: FileSensor.__module__,
-        SENSOR_NAME: FileSensor.__name__,
-    }
-    file = tempfile.NamedTemporaryFile()
-
+    sensor_metadata = SensorMetadata(
+        sensor_module=FileSensor.__module__, sensor_name=FileSensor.__name__, inputs={"path": file.name}
+    )
+    tmp._custom = asdict(sensor_metadata)
     tmp._interface = interfaces
 
     task_inputs = literals.LiteralMap(
@@ -37,11 +35,10 @@ async def test_sensor_engine():
     )
 
     agent = AgentRegistry.get_agent("sensor")
-    res = await agent.create("/tmp", tmp, task_inputs)
+    res = await agent.create(tmp, task_inputs)
 
-    metadata_bytes = cloudpickle.dumps(tmp.custom)
-    assert res.resource_meta == metadata_bytes
-    res = await agent.get(metadata_bytes)
-    assert res.resource.phase == TaskExecution.SUCCEEDED
-    res = await agent.delete(metadata_bytes)
-    assert res == DeleteTaskResponse()
+    assert res == sensor_metadata
+    resource = await agent.get(sensor_metadata)
+    assert resource.phase == TaskExecution.SUCCEEDED
+    res = await agent.delete(sensor_metadata)
+    assert res is None