diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 120fd56cb8..e6a147e67e 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -142,7 +142,7 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max - build-and-push-external-plugin-service-images: + build-and-push-flyteagent-images: runs-on: ubuntu-latest needs: deploy steps: @@ -161,12 +161,12 @@ jobs: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" password: "${{ secrets.FLYTE_BOT_PAT }}" - - name: Prepare External Plugin Service Image Names - id: external-plugin-service-names + - name: Prepare Flyte Agent Image Names + id: flyteagent-names uses: docker/metadata-action@v3 with: images: | - ghcr.io/${{ github.repository_owner }}/external-plugin-service + ghcr.io/${{ github.repository_owner }}/flyteagent tags: | latest ${{ github.sha }} @@ -177,10 +177,10 @@ jobs: context: "." platforms: linux/arm64, linux/amd64 push: ${{ github.event_name == 'release' }} - tags: ${{ steps.external-plugin-service-names.outputs.tags }} + tags: ${{ steps.flyteagent-names.outputs.tags }} build-args: | VERSION=${{ needs.deploy.outputs.version }} - file: ./Dockerfile.external-plugin-service + file: ./Dockerfile.agent cache-from: type=gha cache-to: type=gha,mode=max diff --git a/Dockerfile.external-plugin-service b/Dockerfile.agent similarity index 100% rename from Dockerfile.external-plugin-service rename to Dockerfile.agent diff --git a/doc-requirements.txt b/doc-requirements.txt index 5264673f4f..c1e17439f9 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -244,7 +244,7 @@ flask==2.2.3 # via mlflow flatbuffers==23.1.21 # via tensorflow -flyteidl==1.5.6 +flyteidl==1.5.10 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index b1cf19647d..b2b82831c7 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -25,6 +25,7 @@ class ClientConfig: device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" + audience: typing.Optional[str] = None class ClientConfigStore(object): @@ -174,6 +175,7 @@ def __init__( scopes: typing.Optional[typing.List[str]] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + audience: typing.Optional[str] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -183,6 +185,7 @@ def __init__( self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret + self._audience = audience or cfg.audience super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify) def refresh_credentials(self): @@ -195,14 +198,21 @@ def refresh_credentials(self): """ token_endpoint = self._token_endpoint scopes = self._scopes + audience = self._audience # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) token, expires_in = token_client.get_token( - token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify + token_endpoint=token_endpoint, + authorization_header=authorization_header, + http_proxy_url=self._http_proxy_url, + verify=self._verify, + scopes=scopes, + audience=audience, ) + logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 2e14fe8afc..e5eae32ed7 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -74,6 +74,7 @@ def get_token( authorization_header: typing.Optional[str] = None, client_id: typing.Optional[str] = None, device_code: typing.Optional[str] = None, + audience: typing.Optional[str] = None, grant_type: GrantType = GrantType.CLIENT_CREDS, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, @@ -98,9 +99,12 @@ def get_token( body["device_code"] = device_code if scopes is not None: body["scope"] = ",".join(scopes) + if audience: + body["audience"] = audience proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) + if not response.ok: j = response.json() if "error" in j: diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index ce2992723f..5c4fafe579 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -43,6 +43,7 @@ def get_client_config(self) -> ClientConfig: scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, + audience=public_client_config.audience, ) @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, scopes=cfg.scopes, + audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, ) diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 71b539d36c..c95754e6c6 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -2,11 +2,11 @@ import click import grpc -from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server +from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server -from flytekit.extend.backend.external_plugin_service import BackendPluginServer +from flytekit.extend.backend.agent_service import AgentService -_serve_help = """Start a grpc server for the external plugin service.""" +_serve_help = """Start a grpc server for the agent service.""" @click.command("serve", help=_serve_help) @@ -15,7 +15,7 @@ default="8000", is_flag=False, type=int, - help="Grpc port for the external plugin service", + help="Grpc port for the agent service", ) @click.option( "--worker", @@ -35,11 +35,11 @@ @click.pass_context def serve(_: click.Context, port, worker, timeout): """ - Start a grpc server for the external plugin service. + Start a grpc server for the agent service. """ - click.secho("Starting the external plugin service...", fg="blue") + click.secho("Starting the agent service...", fg="blue") server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker)) - add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server) + add_AsyncAgentServiceServicer_to_server(AgentService(), server) server.add_insecure_port(f"[::]:{port}") server.start() diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index a7e2c69ebd..e31af5f389 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -428,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) + kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file)) return PlatformConfig(**kwargs) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py new file mode 100644 index 0000000000..55f71959fe --- /dev/null +++ b/flytekit/extend/backend/agent_service.py @@ -0,0 +1,54 @@ +import grpc +from flyteidl.admin.agent_pb2 import ( + PERMANENT_FAILURE, + CreateTaskRequest, + CreateTaskResponse, + DeleteTaskRequest, + DeleteTaskResponse, + GetTaskRequest, + GetTaskResponse, + Resource, +) +from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer + +from flytekit import logger +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class AgentService(AsyncAgentServiceServicer): + def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: + try: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + agent = AgentRegistry.get_agent(context, tmp.type) + if agent is None: + return CreateTaskResponse() + return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) + except Exception as e: + logger.error(f"failed to create task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to create task with error {e}") + + def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: + try: + agent = AgentRegistry.get_agent(context, request.task_type) + if agent is None: + return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) + return agent.get(context=context, resource_meta=request.resource_meta) + except Exception as e: + logger.error(f"failed to get task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to get task with error {e}") + + def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: + try: + agent = AgentRegistry.get_agent(context, request.task_type) + if agent is None: + return DeleteTaskResponse() + return agent.delete(context=context, resource_meta=request.resource_meta) + except Exception as e: + logger.error(f"failed to delete task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to delete task with error {e}") diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py new file mode 100644 index 0000000000..0c93d2f60f --- /dev/null +++ b/flytekit/extend/backend/base_agent.py @@ -0,0 +1,107 @@ +import typing +from abc import ABC, abstractmethod + +import grpc +from flyteidl.admin.agent_pb2 import ( + RETRYABLE_FAILURE, + RUNNING, + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + State, +) +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit import logger +from flytekit.models.literals import LiteralMap + + +class AgentBase(ABC): + """ + This is the base class for all agents. It defines the interface that all agents must implement. + The agent service will be run either locally or in a pod, and will be responsible for + invoking agents. The propeller will communicate with the agent service + to create tasks, get the status of tasks, and delete tasks. + + All the agents should be registered in the AgentRegistry. Agent Service + will look up the agent based on the task type. Every task type can only have one agent. + """ + + def __init__(self, task_type: str): + self._task_type = task_type + + @property + def task_type(self) -> str: + """ + task_type is the name of the task type that this agent supports. + """ + return self._task_type + + @abstractmethod + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + """ + Return a Unique ID for the task that was created. It should return error code if the task creation failed. + """ + pass + + @abstractmethod + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + """ + Return the status of the task, and return the outputs in some cases. For example, bigquery job + can't write the structured dataset to the output location, so it returns the output literals to the propeller, + and the propeller will write the structured dataset to the blob store. + """ + pass + + @abstractmethod + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + """ + Delete the task. This call should be idempotent. + """ + pass + + +class AgentRegistry(object): + """ + This is the registry for all agents. The agent service will look up the agent + based on the task type. + """ + + _REGISTRY: typing.Dict[str, AgentBase] = {} + + @staticmethod + def register(agent: AgentBase): + if agent.task_type in AgentRegistry._REGISTRY: + raise ValueError(f"Duplicate agent for task type {agent.task_type}") + AgentRegistry._REGISTRY[agent.task_type] = agent + logger.info(f"Registering an agent for task type {agent.task_type}") + + @staticmethod + def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]: + if task_type not in AgentRegistry._REGISTRY: + logger.error(f"Cannot find agent for task type [{task_type}]") + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Cannot find the agent for task type [{task_type}]") + return None + return AgentRegistry._REGISTRY[task_type] + + +def convert_to_flyte_state(state: str) -> State: + """ + Convert the state from the agent to the state in flyte. + """ + state = state.lower() + if state in ["failed"]: + return RETRYABLE_FAILURE + elif state in ["done", "succeeded"]: + return SUCCEEDED + elif state in ["running"]: + return RUNNING + raise ValueError(f"Unrecognized state: {state}") diff --git a/flytekit/extend/backend/base_plugin.py b/flytekit/extend/backend/base_plugin.py deleted file mode 100644 index 9fc1bc206b..0000000000 --- a/flytekit/extend/backend/base_plugin.py +++ /dev/null @@ -1,107 +0,0 @@ -import typing -from abc import ABC, abstractmethod - -import grpc -from flyteidl.core.tasks_pb2 import TaskTemplate -from flyteidl.service.external_plugin_service_pb2 import ( - RETRYABLE_FAILURE, - RUNNING, - SUCCEEDED, - State, - TaskCreateResponse, - TaskDeleteResponse, - TaskGetResponse, -) - -from flytekit import logger -from flytekit.models.literals import LiteralMap - - -class BackendPluginBase(ABC): - """ - This is the base class for all backend plugins. It defines the interface that all plugins must implement. - The external plugins service will be run either locally or in a pod, and will be responsible for - invoking backend plugins. The propeller will communicate with the external plugins service - to create tasks, get the status of tasks, and delete tasks. - - All the backend plugins should be registered in the BackendPluginRegistry. External plugins service - will look up the plugin based on the task type. Every task type can only have one plugin. - """ - - def __init__(self, task_type: str): - self._task_type = task_type - - @property - def task_type(self) -> str: - """ - task_type is the name of the task type that this plugin supports. - """ - return self._task_type - - @abstractmethod - def create( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - ) -> TaskCreateResponse: - """ - Return a Unique ID for the task that was created. It should return error code if the task creation failed. - """ - pass - - @abstractmethod - def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: - """ - Return the status of the task, and return the outputs in some cases. For example, bigquery job - can't write the structured dataset to the output location, so it returns the output literals to the propeller, - and the propeller will write the structured dataset to the blob store. - """ - pass - - @abstractmethod - def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: - """ - Delete the task. This call should be idempotent. - """ - pass - - -class BackendPluginRegistry(object): - """ - This is the registry for all backend plugins. The external plugins service will look up the plugin - based on the task type. - """ - - _REGISTRY: typing.Dict[str, BackendPluginBase] = {} - - @staticmethod - def register(plugin: BackendPluginBase): - if plugin.task_type in BackendPluginRegistry._REGISTRY: - raise ValueError(f"Duplicate plugin for task type {plugin.task_type}") - BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin - logger.info(f"Registering backend plugin for task type {plugin.task_type}") - - @staticmethod - def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]: - if task_type not in BackendPluginRegistry._REGISTRY: - logger.error(f"Cannot find backend plugin for task type [{task_type}]") - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(f"Cannot find backend plugin for task type [{task_type}]") - return None - return BackendPluginRegistry._REGISTRY[task_type] - - -def convert_to_flyte_state(state: str) -> State: - """ - Convert the state from the backend plugin to the state in flyte. - """ - state = state.lower() - if state in ["failed"]: - return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: - return SUCCEEDED - elif state in ["running"]: - return RUNNING - raise ValueError(f"Unrecognized state: {state}") diff --git a/flytekit/extend/backend/external_plugin_service.py b/flytekit/extend/backend/external_plugin_service.py deleted file mode 100644 index e820a320b1..0000000000 --- a/flytekit/extend/backend/external_plugin_service.py +++ /dev/null @@ -1,53 +0,0 @@ -import grpc -from flyteidl.service.external_plugin_service_pb2 import ( - PERMANENT_FAILURE, - TaskCreateRequest, - TaskCreateResponse, - TaskDeleteRequest, - TaskDeleteResponse, - TaskGetRequest, - TaskGetResponse, -) -from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer - -from flytekit import logger -from flytekit.extend.backend.base_plugin import BackendPluginRegistry -from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate - - -class BackendPluginServer(ExternalPluginServiceServicer): - def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse: - try: - tmp = TaskTemplate.from_flyte_idl(request.template) - inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - plugin = BackendPluginRegistry.get_plugin(context, tmp.type) - if plugin is None: - return TaskCreateResponse() - return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) - except Exception as e: - logger.error(f"failed to create task with error {e}") - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to create task with error {e}") - - def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse: - try: - plugin = BackendPluginRegistry.get_plugin(context, request.task_type) - if plugin is None: - return TaskGetResponse(state=PERMANENT_FAILURE) - return plugin.get(context=context, job_id=request.job_id) - except Exception as e: - logger.error(f"failed to get task with error {e}") - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to get task with error {e}") - - def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse: - try: - plugin = BackendPluginRegistry.get_plugin(context, request.task_type) - if plugin is None: - return TaskDeleteResponse() - return plugin.delete(context=context, job_id=request.job_id) - except Exception as e: - logger.error(f"failed to delete task with error {e}") - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to delete task with error {e}") diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py index 416a021516..0e0fe80bc7 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py @@ -9,7 +9,8 @@ BigQueryConfig BigQueryTask + BigQueryAgent """ -from .backend_plugin import BigQueryPlugin +from .agent import BigQueryAgent from .task import BigQueryConfig, BigQueryTask diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py similarity index 70% rename from plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py rename to plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index acd5ece430..0a9a22923e 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -1,18 +1,15 @@ import datetime +import json +from dataclasses import asdict, dataclass from typing import Dict, Optional import grpc -from flyteidl.service.external_plugin_service_pb2 import ( - SUCCEEDED, - TaskCreateResponse, - TaskDeleteResponse, - TaskGetResponse, -) +from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry, convert_to_flyte_state +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -30,7 +27,12 @@ } -class BigQueryPlugin(BackendPluginBase): +@dataclass +class Metadata: + job_id: str + + +class BigQueryAgent(AgentBase): def __init__(self): super().__init__(task_type="bigquery_query_job_task") @@ -40,7 +42,7 @@ def create( output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - ) -> TaskCreateResponse: + ) -> CreateTaskResponse: job_config = None if inputs: ctx = FlyteContextManager.current_context() @@ -61,11 +63,14 @@ def create( client = bigquery.Client(project=custom["ProjectID"], location=custom["Location"]) query_job = client.query(task_template.sql.statement, job_config=job_config) - return TaskCreateResponse(job_id=str(query_job.job_id)) + return CreateTaskResponse( + resource_meta=json.dumps(asdict(Metadata(job_id=str(query_job.job_id)))).encode("utf-8") + ) - def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: client = bigquery.Client() - job = client.get_job(job_id) + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + job = client.get_job(metadata.job_id) cur_state = convert_to_flyte_state(str(job.state)) res = None @@ -83,12 +88,13 @@ def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: } ) - return TaskGetResponse(state=cur_state, outputs=res.to_flyte_idl()) + return GetTaskResponse(resource=Resource(state=cur_state, outputs=res.to_flyte_idl())) - def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: client = bigquery.Client() - client.cancel_job(job_id) - return TaskDeleteResponse() + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + client.cancel_job(metadata.job_id) + return DeleteTaskResponse() -BackendPluginRegistry.register(BigQueryPlugin()) +AgentRegistry.register(BigQueryAgent()) diff --git a/plugins/flytekit-bigquery/tests/test_backend_plugin.py b/plugins/flytekit-bigquery/tests/test_agent.py similarity index 77% rename from plugins/flytekit-bigquery/tests/test_backend_plugin.py rename to plugins/flytekit-bigquery/tests/test_agent.py index c95cf308a7..237c5c0718 100644 --- a/plugins/flytekit-bigquery/tests/test_backend_plugin.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -1,12 +1,15 @@ +import json +from dataclasses import asdict from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import grpc -from flyteidl.service.external_plugin_service_pb2 import SUCCEEDED +from flyteidl.admin.agent_pb2 import SUCCEEDED +from flytekitplugins.bigquery.agent import Metadata import flytekit.models.interface as interface_models -from flytekit.extend.backend.base_plugin import BackendPluginRegistry +from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals, task, types from flytekit.models.core.identifier import ResourceType @@ -15,7 +18,7 @@ @mock.patch("google.cloud.bigquery.job.QueryJob") @mock.patch("google.cloud.bigquery.Client") -def test_bigquery_plugin(mock_client, mock_query_job): +def test_bigquery_agent(mock_client, mock_query_job): job_id = "dummy_id" mock_instance = mock_client.return_value mock_query_job_instance = mock_query_job.return_value @@ -39,7 +42,7 @@ def __init__(self): mock_instance.cancel_job.return_value = MockJob() ctx = MagicMock(spec=grpc.ServicerContext) - p = BackendPluginRegistry.get_plugin(ctx, "bigquery_query_job_task") + agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" @@ -84,11 +87,13 @@ def __init__(self): sql=Sql("SELECT 1"), ) - assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == job_id - res = p.get(ctx, job_id) - assert res.state == SUCCEEDED + metadata_bytes = json.dumps(asdict(Metadata(job_id="dummy_id"))).encode("utf-8") + assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes + res = agent.get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED assert ( - res.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" + res.resource.outputs.literals["results"].scalar.structured_dataset.uri + == "bq://dummy_project:dummy_dataset.dummy_table" ) - p.delete(ctx, job_id) + agent.delete(ctx, metadata_bytes) mock_instance.cancel_job.assert_called() diff --git a/setup.py b/setup.py index 5273590d9f..4bae776373 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.6", + "flyteidl>=1.5.10", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py new file mode 100644 index 0000000000..4fbba075af --- /dev/null +++ b/tests/flytekit/unit/extend/test_agent.py @@ -0,0 +1,121 @@ +import json +import typing +from dataclasses import asdict, dataclass +from datetime import timedelta +from unittest.mock import MagicMock + +import grpc +from flyteidl.admin.agent_pb2 import ( + PERMANENT_FAILURE, + SUCCEEDED, + CreateTaskRequest, + CreateTaskResponse, + DeleteTaskRequest, + DeleteTaskResponse, + GetTaskRequest, + GetTaskResponse, + Resource, +) + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.agent_service import AgentService +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +dummy_id = "dummy_id" + + +@dataclass +class Metadata: + job_id: str + + +class DummyAgent(AgentBase): + def __init__(self): + super().__init__(task_type="dummy") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + return DeleteTaskResponse() + + +AgentRegistry.register(DummyAgent()) + +task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") +task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", +) + +int_type = types.LiteralType(types.SimpleType.INTEGER) +interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + }, + {}, +) +task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, +) + +dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="dummy", + custom={}, +) + + +def test_dummy_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent(ctx, "dummy") + metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes + assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED + assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() + + +def test_agent_server(): + service = AgentService() + ctx = MagicMock(spec=grpc.ServicerContext) + request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) + + metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + assert service.CreateTask(request, ctx).resource_meta == metadata_bytes + assert ( + service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx).resource.state + == SUCCEEDED + ) + assert ( + service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + == DeleteTaskResponse() + ) + + res = service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx) + assert res.resource.state == PERMANENT_FAILURE diff --git a/tests/flytekit/unit/extend/test_backend_plugin.py b/tests/flytekit/unit/extend/test_backend_plugin.py deleted file mode 100644 index 9dfd20d99e..0000000000 --- a/tests/flytekit/unit/extend/test_backend_plugin.py +++ /dev/null @@ -1,105 +0,0 @@ -import typing -from datetime import timedelta -from unittest.mock import MagicMock - -import grpc -from flyteidl.service.external_plugin_service_pb2 import ( - PERMANENT_FAILURE, - SUCCEEDED, - TaskCreateRequest, - TaskCreateResponse, - TaskDeleteRequest, - TaskDeleteResponse, - TaskGetRequest, - TaskGetResponse, -) - -import flytekit.models.interface as interface_models -from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry -from flytekit.extend.backend.external_plugin_service import BackendPluginServer -from flytekit.models import literals, task, types -from flytekit.models.core.identifier import Identifier, ResourceType -from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate - -dummy_id = "dummy_id" - - -class DummyPlugin(BackendPluginBase): - def __init__(self): - super().__init__(task_type="dummy") - - def create( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - ) -> TaskCreateResponse: - return TaskCreateResponse(job_id=dummy_id) - - def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: - return TaskGetResponse(state=SUCCEEDED) - - def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse: - return TaskDeleteResponse() - - -BackendPluginRegistry.register(DummyPlugin()) - -task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") -task_metadata = task.TaskMetadata( - True, - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), - timedelta(days=1), - literals.RetryStrategy(3), - True, - "0.1.1b0", - "This is deprecated!", - True, - "A", -) - -int_type = types.LiteralType(types.SimpleType.INTEGER) -interfaces = interface_models.TypedInterface( - { - "a": interface_models.Variable(int_type, "description1"), - }, - {}, -) -task_inputs = literals.LiteralMap( - { - "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), - }, -) - -dummy_template = TaskTemplate( - id=task_id, - metadata=task_metadata, - interface=interfaces, - type="dummy", - custom={}, -) - - -def test_dummy_plugin(): - ctx = MagicMock(spec=grpc.ServicerContext) - p = BackendPluginRegistry.get_plugin(ctx, "dummy") - assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == dummy_id - assert p.get(ctx, dummy_id).state == SUCCEEDED - assert p.delete(ctx, dummy_id) == TaskDeleteResponse() - - -def test_backend_plugin_server(): - server = BackendPluginServer() - ctx = MagicMock(spec=grpc.ServicerContext) - request = TaskCreateRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() - ) - - assert server.CreateTask(request, ctx).job_id == dummy_id - assert server.GetTask(TaskGetRequest(task_type="dummy", job_id=dummy_id), ctx).state == SUCCEEDED - assert server.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse() - - res = server.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx) - assert res.state == PERMANENT_FAILURE