Skip to content

Commit

Permalink
Merge branch 'master' of github.com:flyteorg/flytekit into downloadfile
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw committed Jun 10, 2023
2 parents c6ac046 + 3370a96 commit 49678b2
Show file tree
Hide file tree
Showing 18 changed files with 354 additions and 308 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max

build-and-push-external-plugin-service-images:
build-and-push-flyteagent-images:
runs-on: ubuntu-latest
needs: deploy
steps:
Expand All @@ -161,12 +161,12 @@ jobs:
registry: ghcr.io
username: "${{ secrets.FLYTE_BOT_USERNAME }}"
password: "${{ secrets.FLYTE_BOT_PAT }}"
- name: Prepare External Plugin Service Image Names
id: external-plugin-service-names
- name: Prepare Flyte Agent Image Names
id: flyteagent-names
uses: docker/metadata-action@v3
with:
images: |
ghcr.io/${{ github.repository_owner }}/external-plugin-service
ghcr.io/${{ github.repository_owner }}/flyteagent
tags: |
latest
${{ github.sha }}
Expand All @@ -177,10 +177,10 @@ jobs:
context: "."
platforms: linux/arm64, linux/amd64
push: ${{ github.event_name == 'release' }}
tags: ${{ steps.external-plugin-service-names.outputs.tags }}
tags: ${{ steps.flyteagent-names.outputs.tags }}
build-args: |
VERSION=${{ needs.deploy.outputs.version }}
file: ./Dockerfile.external-plugin-service
file: ./Dockerfile.agent
cache-from: type=gha
cache-to: type=gha,mode=max

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.5.6
flyteidl==1.5.10
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
12 changes: 11 additions & 1 deletion flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ClientConfig:
device_authorization_endpoint: typing.Optional[str] = None
scopes: typing.List[str] = None
header_key: str = "authorization"
audience: typing.Optional[str] = None


class ClientConfigStore(object):
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
audience: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -183,6 +185,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
self._audience = audience or cfg.audience
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
Expand All @@ -195,14 +198,21 @@ def refresh_credentials(self):
"""
token_endpoint = self._token_endpoint
scopes = self._scopes
audience = self._audience

# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)

token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify
token_endpoint=token_endpoint,
authorization_header=authorization_header,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
scopes=scopes,
audience=audience,
)

logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand Down
4 changes: 4 additions & 0 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_token(
authorization_header: typing.Optional[str] = None,
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
Expand All @@ -98,9 +99,12 @@ def get_token(
body["device_code"] = device_code
if scopes is not None:
body["scope"] = ",".join(scopes)
if audience:
body["audience"] = audience

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not response.ok:
j = response.json()
if "error" in j:
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_client_config(self) -> ClientConfig:
scopes=public_client_config.scopes,
header_key=public_client_config.authorization_metadata_key or None,
device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
audience=public_client_config.audience,
)


Expand Down Expand Up @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
audience=cfg.audience,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
Expand Down
14 changes: 7 additions & 7 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import click
import grpc
from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server

from flytekit.extend.backend.external_plugin_service import BackendPluginServer
from flytekit.extend.backend.agent_service import AgentService

_serve_help = """Start a grpc server for the external plugin service."""
_serve_help = """Start a grpc server for the agent service."""


@click.command("serve", help=_serve_help)
Expand All @@ -15,7 +15,7 @@
default="8000",
is_flag=False,
type=int,
help="Grpc port for the external plugin service",
help="Grpc port for the agent service",
)
@click.option(
"--worker",
Expand All @@ -35,11 +35,11 @@
@click.pass_context
def serve(_: click.Context, port, worker, timeout):
"""
Start a grpc server for the external plugin service.
Start a grpc server for the agent service.
"""
click.secho("Starting the external plugin service...", fg="blue")
click.secho("Starting the agent service...", fg="blue")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker))
add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server)
add_AsyncAgentServiceServicer_to_server(AgentService(), server)

server.add_insecure_port(f"[::]:{port}")
server.start()
Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))

kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file))
return PlatformConfig(**kwargs)

Expand Down
54 changes: 54 additions & 0 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class AgentService(AsyncAgentServiceServicer):
def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(context, tmp.type)
if agent is None:
return CreateTaskResponse()
return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
return agent.get(context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return DeleteTaskResponse()
return agent.delete(context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
107 changes: 107 additions & 0 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import typing
from abc import ABC, abstractmethod

import grpc
from flyteidl.admin.agent_pb2 import (
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
State,
)
from flyteidl.core.tasks_pb2 import TaskTemplate

from flytekit import logger
from flytekit.models.literals import LiteralMap


class AgentBase(ABC):
"""
This is the base class for all agents. It defines the interface that all agents must implement.
The agent service will be run either locally or in a pod, and will be responsible for
invoking agents. The propeller will communicate with the agent service
to create tasks, get the status of tasks, and delete tasks.
All the agents should be registered in the AgentRegistry. Agent Service
will look up the agent based on the task type. Every task type can only have one agent.
"""

def __init__(self, task_type: str):
self._task_type = task_type

@property
def task_type(self) -> str:
"""
task_type is the name of the task type that this agent supports.
"""
return self._task_type

@abstractmethod
def create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> CreateTaskResponse:
"""
Return a Unique ID for the task that was created. It should return error code if the task creation failed.
"""
pass

@abstractmethod
def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
"""
Return the status of the task, and return the outputs in some cases. For example, bigquery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller will write the structured dataset to the blob store.
"""
pass

@abstractmethod
def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
"""
Delete the task. This call should be idempotent.
"""
pass


class AgentRegistry(object):
"""
This is the registry for all agents. The agent service will look up the agent
based on the task type.
"""

_REGISTRY: typing.Dict[str, AgentBase] = {}

@staticmethod
def register(agent: AgentBase):
if agent.task_type in AgentRegistry._REGISTRY:
raise ValueError(f"Duplicate agent for task type {agent.task_type}")
AgentRegistry._REGISTRY[agent.task_type] = agent
logger.info(f"Registering an agent for task type {agent.task_type}")

@staticmethod
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
logger.error(f"Cannot find agent for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find the agent for task type [{task_type}]")
return None
return AgentRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
"""
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
raise ValueError(f"Unrecognized state: {state}")
Loading

0 comments on commit 49678b2

Please sign in to comment.