Skip to content

Commit

Permalink
External Plugin Service (grpc) (#1524)
Browse files Browse the repository at this point in the history
External Plugin Service
  • Loading branch information
pingsutw authored May 6, 2023
1 parent 0b0de27 commit 35e52ef
Show file tree
Hide file tree
Showing 17 changed files with 570 additions and 3 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,48 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max

build-and-push-external-plugin-service-images:
runs-on: ubuntu-latest
needs: deploy
steps:
- uses: actions/checkout@v2
with:
fetch-depth: "0"
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v1
- name: Login to GitHub Container Registry
if: ${{ github.event_name == 'release' }}
uses: docker/login-action@v1
with:
registry: ghcr.io
username: "${{ secrets.FLYTE_BOT_USERNAME }}"
password: "${{ secrets.FLYTE_BOT_PAT }}"
- name: Prepare External Plugin Service Image Names
id: external-plugin-service-names
uses: docker/metadata-action@v3
with:
images: |
ghcr.io/${{ github.repository_owner }}/external-plugin-service
tags: |
latest
${{ github.sha }}
${{ needs.deploy.outputs.version }}
- name: Push External Plugin Service Image to GitHub Registry
uses: docker/build-push-action@v2
with:
context: "."
platforms: linux/arm64, linux/amd64
push: ${{ github.event_name == 'release' }}
tags: ${{ steps.external-plugin-service-names.outputs.tags }}
build-args: |
VERSION=${{ needs.deploy.outputs.version }}
file: ./Dockerfile
cache-from: type=gha
cache-to: type=gha,mode=max

build-and-push-spark-images:
runs-on: ubuntu-latest
needs: deploy
Expand Down
10 changes: 10 additions & 0 deletions Dockerfile.external-plugin-service
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:3.9-slim-buster

MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
RUN pip install -U flytekit==$VERSION \
flytekitplugins-bigquery==$VERSION \

CMD pyflyte serve --port 8000
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.3.12
flyteidl==1.3.16
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def get_upload_signed_url(

def get_download_signed_url(
self, native_url: str, expires_in: datetime.timedelta = None
) -> _data_proxy_pb2.CreateUploadLocationResponse:
) -> _data_proxy_pb2.CreateDownloadLocationRequest:
expires_in_pb = None
if expires_in:
expires_in_pb = Duration()
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.run import run
from flytekit.clis.sdk_in_container.serialize import serialize
from flytekit.clis.sdk_in_container.serve import serve
from flytekit.configuration.internal import LocalSDK
from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.user import FlyteInvalidInputException
Expand Down Expand Up @@ -134,6 +135,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool):
main.add_command(run)
main.add_command(register)
main.add_command(backfill)
main.add_command(serve)
main.add_command(build)
main.add_command(launchplan)
main.epilog
Expand Down
46 changes: 46 additions & 0 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from concurrent import futures

import click
import grpc
from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server

from flytekit.extend.backend.external_plugin_service import BackendPluginServer

_serve_help = """Start a grpc server for the external plugin service."""


@click.command("serve", help=_serve_help)
@click.option(
"--port",
default="8000",
is_flag=False,
type=int,
help="Grpc port for the external plugin service",
)
@click.option(
"--worker",
default="10",
is_flag=False,
type=int,
help="Number of workers for the grpc server",
)
@click.option(
"--timeout",
default=None,
is_flag=False,
type=int,
help="It will wait for the specified number of seconds before shutting down grpc server. It should only be used "
"for testing.",
)
@click.pass_context
def serve(_: click.Context, port, worker, timeout):
"""
Start a grpc server for the external plugin service.
"""
click.secho("Starting the external plugin service...", fg="blue")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker))
add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server)

server.add_insecure_port(f"[::]:{port}")
server.start()
server.wait_for_termination(timeout=timeout)
Empty file.
107 changes: 107 additions & 0 deletions flytekit/extend/backend/base_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import typing
from abc import ABC, abstractmethod

import grpc
from flyteidl.core.tasks_pb2 import TaskTemplate
from flyteidl.service.external_plugin_service_pb2 import (
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
State,
TaskCreateResponse,
TaskDeleteResponse,
TaskGetResponse,
)

from flytekit import logger
from flytekit.models.literals import LiteralMap


class BackendPluginBase(ABC):
"""
This is the base class for all backend plugins. It defines the interface that all plugins must implement.
The external plugins service will be run either locally or in a pod, and will be responsible for
invoking backend plugins. The propeller will communicate with the external plugins service
to create tasks, get the status of tasks, and delete tasks.
All the backend plugins should be registered in the BackendPluginRegistry. External plugins service
will look up the plugin based on the task type. Every task type can only have one plugin.
"""

def __init__(self, task_type: str):
self._task_type = task_type

@property
def task_type(self) -> str:
"""
task_type is the name of the task type that this plugin supports.
"""
return self._task_type

@abstractmethod
def create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> TaskCreateResponse:
"""
Return a Unique ID for the task that was created. It should return error code if the task creation failed.
"""
pass

@abstractmethod
def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse:
"""
Return the status of the task, and return the outputs in some cases. For example, bigquery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller will write the structured dataset to the blob store.
"""
pass

@abstractmethod
def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse:
"""
Delete the task. This call should be idempotent.
"""
pass


class BackendPluginRegistry(object):
"""
This is the registry for all backend plugins. The external plugins service will look up the plugin
based on the task type.
"""

_REGISTRY: typing.Dict[str, BackendPluginBase] = {}

@staticmethod
def register(plugin: BackendPluginBase):
if plugin.task_type in BackendPluginRegistry._REGISTRY:
raise ValueError(f"Duplicate plugin for task type {plugin.task_type}")
BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin
logger.info(f"Registering backend plugin for task type {plugin.task_type}")

@staticmethod
def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]:
if task_type not in BackendPluginRegistry._REGISTRY:
logger.error(f"Cannot find backend plugin for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find backend plugin for task type [{task_type}]")
return None
return BackendPluginRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
"""
Convert the state from the backend plugin to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
raise ValueError(f"Unrecognized state: {state}")
53 changes: 53 additions & 0 deletions flytekit/extend/backend/external_plugin_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import grpc
from flyteidl.service.external_plugin_service_pb2 import (
PERMANENT_FAILURE,
TaskCreateRequest,
TaskCreateResponse,
TaskDeleteRequest,
TaskDeleteResponse,
TaskGetRequest,
TaskGetResponse,
)
from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer

from flytekit import logger
from flytekit.extend.backend.base_plugin import BackendPluginRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class BackendPluginServer(ExternalPluginServiceServicer):
def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse:
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
plugin = BackendPluginRegistry.get_plugin(context, tmp.type)
if plugin is None:
return TaskCreateResponse()
return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse:
try:
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
return TaskGetResponse(state=PERMANENT_FAILURE)
return plugin.get(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse:
try:
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
return TaskDeleteResponse()
return plugin.delete(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
BigQueryTask
"""

from .backend_plugin import BigQueryPlugin
from .task import BigQueryConfig, BigQueryTask
Loading

0 comments on commit 35e52ef

Please sign in to comment.