Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update async agent example #4906

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions docs/flyte_agents/creating_an_agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Metadata:
# If you add s3 file path, the agent will check if the file exists.
job_id: str

class CustomAgent(AgentBase):
class CustomAsyncAgent(AsyncAgentBase):
def __init__(self, task_type: str):
# Each agent should have a unique task type.
# The Flyte agent service will use the task type
Expand All @@ -52,48 +52,34 @@ class CustomAgent(AgentBase):

def create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs,
) -> TaskCreateResponse:
# 1. Submit the task to the external service (BigQuery, DataBricks, etc.)
# 2. Create metadata for the task, such as jobID.
# 3. Return the metadata, serialized to bytes.
res = requests.post(url, json=data)
return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=str(res.job_id)))).encode("utf-8"))

def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> TaskGetResponse:
def get(self, resource_meta: bytes, **kwargs) -> TaskGetResponse:
# 1. Deserialize the metadata.
# 2. Use the metadata to get the job status.
# 3. Return the job status.
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
res = requests.get(url, json={"job_id": metadata.job_id})
return GetTaskResponse(resource=Resource(state=res.state)

def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> TaskDeleteResponse:
def delete(self, resource_meta: bytes, **kwargs) -> TaskDeleteResponse:
# 1. Deserialize the metadata.
# 2. Use the metadata to delete the job.
# 3. If failed to delete the job, add the error message to the grpc context.
# context.set_code(grpc.StatusCode.INTERNAL)
# context.set_details(f"failed to create task with error {e}")
try:
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
requests.delete(url, json={"job_id": metadata.job_id})
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
requests.delete(url, json={"job_id": metadata.job_id})
return DeleteTaskResponse()

# To register the custom agent
AgentRegistry.register(CustomAgent())
AgentRegistry.register(CustomAsyncAgent())
```

For an example implementation, see the [BigQuery Agent](https://github.com/flyteorg/flytekit/blob/9977aac26242ebbede8e00d476c2fbc59ac5487a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py#L35).

## Sync agent interface specification
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neverett, I'll add it back in the separate PR because some sync agent APIs are still under discussion.


To create a new async agent, extend the `AgentBase` class in the `flytekit.backend` module and implement the `execute` method.

- `execute`: This method is used to initiate a new job and return the response.
6 changes: 6 additions & 0 deletions docs/flyte_agents/testing_agents_locally.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ Agents can be tested locally without configuring or running the backend server,

The task inherited from `AsyncAgentExecutorMixin` can be executed locally, allowing flytekit to mimic FlytePropeller's behavior to call the agent.

```python
class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]):
def __init__(self, name: str, **kwargs):
...
```

:::{note}

In some cases, you will need to store credentials in your local environment when testing locally.
Expand Down
Loading