-
Notifications
You must be signed in to change notification settings - Fork 300
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of github.com:flyteorg/flytekit into downloadfile
- Loading branch information
Showing
18 changed files
with
354 additions
and
308 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import grpc | ||
from flyteidl.admin.agent_pb2 import ( | ||
PERMANENT_FAILURE, | ||
CreateTaskRequest, | ||
CreateTaskResponse, | ||
DeleteTaskRequest, | ||
DeleteTaskResponse, | ||
GetTaskRequest, | ||
GetTaskResponse, | ||
Resource, | ||
) | ||
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer | ||
|
||
from flytekit import logger | ||
from flytekit.extend.backend.base_agent import AgentRegistry | ||
from flytekit.models.literals import LiteralMap | ||
from flytekit.models.task import TaskTemplate | ||
|
||
|
||
class AgentService(AsyncAgentServiceServicer): | ||
def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: | ||
try: | ||
tmp = TaskTemplate.from_flyte_idl(request.template) | ||
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None | ||
agent = AgentRegistry.get_agent(context, tmp.type) | ||
if agent is None: | ||
return CreateTaskResponse() | ||
return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) | ||
except Exception as e: | ||
logger.error(f"failed to create task with error {e}") | ||
context.set_code(grpc.StatusCode.INTERNAL) | ||
context.set_details(f"failed to create task with error {e}") | ||
|
||
def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: | ||
try: | ||
agent = AgentRegistry.get_agent(context, request.task_type) | ||
if agent is None: | ||
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) | ||
return agent.get(context=context, resource_meta=request.resource_meta) | ||
except Exception as e: | ||
logger.error(f"failed to get task with error {e}") | ||
context.set_code(grpc.StatusCode.INTERNAL) | ||
context.set_details(f"failed to get task with error {e}") | ||
|
||
def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: | ||
try: | ||
agent = AgentRegistry.get_agent(context, request.task_type) | ||
if agent is None: | ||
return DeleteTaskResponse() | ||
return agent.delete(context=context, resource_meta=request.resource_meta) | ||
except Exception as e: | ||
logger.error(f"failed to delete task with error {e}") | ||
context.set_code(grpc.StatusCode.INTERNAL) | ||
context.set_details(f"failed to delete task with error {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import typing | ||
from abc import ABC, abstractmethod | ||
|
||
import grpc | ||
from flyteidl.admin.agent_pb2 import ( | ||
RETRYABLE_FAILURE, | ||
RUNNING, | ||
SUCCEEDED, | ||
CreateTaskResponse, | ||
DeleteTaskResponse, | ||
GetTaskResponse, | ||
State, | ||
) | ||
from flyteidl.core.tasks_pb2 import TaskTemplate | ||
|
||
from flytekit import logger | ||
from flytekit.models.literals import LiteralMap | ||
|
||
|
||
class AgentBase(ABC): | ||
""" | ||
This is the base class for all agents. It defines the interface that all agents must implement. | ||
The agent service will be run either locally or in a pod, and will be responsible for | ||
invoking agents. The propeller will communicate with the agent service | ||
to create tasks, get the status of tasks, and delete tasks. | ||
All the agents should be registered in the AgentRegistry. Agent Service | ||
will look up the agent based on the task type. Every task type can only have one agent. | ||
""" | ||
|
||
def __init__(self, task_type: str): | ||
self._task_type = task_type | ||
|
||
@property | ||
def task_type(self) -> str: | ||
""" | ||
task_type is the name of the task type that this agent supports. | ||
""" | ||
return self._task_type | ||
|
||
@abstractmethod | ||
def create( | ||
self, | ||
context: grpc.ServicerContext, | ||
output_prefix: str, | ||
task_template: TaskTemplate, | ||
inputs: typing.Optional[LiteralMap] = None, | ||
) -> CreateTaskResponse: | ||
""" | ||
Return a Unique ID for the task that was created. It should return error code if the task creation failed. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: | ||
""" | ||
Return the status of the task, and return the outputs in some cases. For example, bigquery job | ||
can't write the structured dataset to the output location, so it returns the output literals to the propeller, | ||
and the propeller will write the structured dataset to the blob store. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: | ||
""" | ||
Delete the task. This call should be idempotent. | ||
""" | ||
pass | ||
|
||
|
||
class AgentRegistry(object): | ||
""" | ||
This is the registry for all agents. The agent service will look up the agent | ||
based on the task type. | ||
""" | ||
|
||
_REGISTRY: typing.Dict[str, AgentBase] = {} | ||
|
||
@staticmethod | ||
def register(agent: AgentBase): | ||
if agent.task_type in AgentRegistry._REGISTRY: | ||
raise ValueError(f"Duplicate agent for task type {agent.task_type}") | ||
AgentRegistry._REGISTRY[agent.task_type] = agent | ||
logger.info(f"Registering an agent for task type {agent.task_type}") | ||
|
||
@staticmethod | ||
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]: | ||
if task_type not in AgentRegistry._REGISTRY: | ||
logger.error(f"Cannot find agent for task type [{task_type}]") | ||
context.set_code(grpc.StatusCode.NOT_FOUND) | ||
context.set_details(f"Cannot find the agent for task type [{task_type}]") | ||
return None | ||
return AgentRegistry._REGISTRY[task_type] | ||
|
||
|
||
def convert_to_flyte_state(state: str) -> State: | ||
""" | ||
Convert the state from the agent to the state in flyte. | ||
""" | ||
state = state.lower() | ||
if state in ["failed"]: | ||
return RETRYABLE_FAILURE | ||
elif state in ["done", "succeeded"]: | ||
return SUCCEEDED | ||
elif state in ["running"]: | ||
return RUNNING | ||
raise ValueError(f"Unrecognized state: {state}") |
Oops, something went wrong.