Skip to content

Commit

Permalink
Merge branch 'tomeu/tlk-1042-welcome-state-for-new-conversation' of g…
Browse files Browse the repository at this point in the history
…ithub.com:cohere-ai/cohere-toolkit into khalil/TLK-1043
  • Loading branch information
knajjars committed Aug 23, 2024
2 parents 48d7fab + dede5cb commit d21802a
Show file tree
Hide file tree
Showing 51 changed files with 1,189 additions and 510 deletions.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ wandb/
secrets.toml

# VScode settings
.vscode/
.vscode/*
!.vscode/extensions.json
!.vscode/settings.default.json

.DS_Store

Expand All @@ -156,3 +158,6 @@ secrets.toml
/src/backend/config/configuration.yaml
/src/backend/config/secrets.yaml
logs/
credentials.json
token.json
src/interfaces/assistants_web/bun.lockb
5 changes: 5 additions & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"recommendations": [
"charliermarsh.ruff"
]
}
8 changes: 8 additions & 0 deletions .vscode/settings.default.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.codeActionsOnSave": {
"notebook.source.fixAll": "explicit",
// uncomment this if you don't want the editor to automatically organize imports
// "source.organizeImports": "never"
}
}
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ win-setup:
poetry install --with setup --verbose
poetry run python src/backend/cli/main.py

.PHONY: typecheck
typecheck:
poetry run pyright

.PHONY: lint
lint:
poetry run ruff check
Expand Down Expand Up @@ -123,3 +127,7 @@ test-db:
.PHONY: dev-sync
dev-sync:
@docker compose up --build sync_worker sync_publisher flower -d

.PHONY: dev-sync-down
dev-sync-down:
@docker compose down sync_worker sync_publisher flower
10 changes: 9 additions & 1 deletion docs/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,16 @@ If you also need to install the community features, run:
poetry install --with community
```

The codebase is formatted and linted using [Ruff](https://docs.astral.sh/ruff/).
The codebase is formatted and linted using [Ruff](https://docs.astral.sh/ruff/).

To check for linter and formatter errors, run

```
make lint
```

To apply automatic fixes, run

```
make lint-fix
```
Expand All @@ -271,11 +273,17 @@ Run type checker:
- Run with `pyright`
- Configure in [pyproject.toml](../pyproject.toml) under `[tool.pyright]`

### VSCode recommendations

- Install the [Ruff VSCode Extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff)
- Copy the contents of `.vscode/settings.default.json` into `.vscode/settings.json`

## Setting up the Environment Variables

**Please confirm that you have at least one configuration of the Cohere Platform, SageMaker, Bedrock or Azure.**

You have two methods to set up the environment variables:

1. Run `make setup` and follow the instructions to configure it.
2. Copy the contents of `configuration.template.yaml` and `secrets.template.yaml` files to new `configuration.yaml` and `secrets.yaml` files.

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pyright]
include = ["src/backend/services/metrics.py"]
include = [
"src/backend/services/metrics.py",
"src/backend/tools/google_drive/sync/actions/",
]
defineConstant = { DEBUG = true }
reportMissingImports = true
reportMissingTypeStubs = false
Expand Down
9 changes: 9 additions & 0 deletions src/backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
target_metadata = Base.metadata


def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and reflected and compare_to is None:
return False
else:
return True


def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
Expand All @@ -45,6 +52,7 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
include_object=include_object,
)

with context.begin_transaction():
Expand Down Expand Up @@ -77,6 +85,7 @@ def process_revision_directives(context, revision, directives):
connection=connection,
target_metadata=target_metadata,
process_revision_directives=process_revision_directives,
include_object=include_object,
)

with context.begin_transaction():
Expand Down
38 changes: 38 additions & 0 deletions src/backend/alembic/versions/2024_08_22_ac3933258035_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Revision ID: ac3933258035
Revises: 08bcb9a24d9b
Create Date: 2024-08-22 20:34:37.547325
"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'ac3933258035'
down_revision: Union[str, None] = '08bcb9a24d9b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('agent_tasks',
sa.Column('agent_id', sa.String(), nullable=False),
sa.Column('task_id', sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('agent_id', 'task_id', name='unique_agent_task')
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('agent_tasks')
# ### end Alembic commands ###
28 changes: 28 additions & 0 deletions src/backend/crud/agent_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List

from sqlalchemy.orm import Session

from backend.database_models.agent_task import AgentTask, SyncCeleryTaskMeta
from backend.services.logger.utils import LoggerFactory
from backend.services.transaction import validate_transaction

logger = LoggerFactory().get_logger()


@validate_transaction
def create_agent_task(db: Session, agent_id: str, task_id: str) -> AgentTask:
agent_task = AgentTask(agent_id=agent_id, task_id=task_id)
db.add(agent_task)
db.commit()
db.refresh(agent_task)
return agent_task


@validate_transaction
def get_agent_tasks_by_agent_id(db: Session, agent_id: str) -> List[SyncCeleryTaskMeta]:
return (
db.query(SyncCeleryTaskMeta)
.join(AgentTask, AgentTask.task_id == SyncCeleryTaskMeta.task_id)
.filter(AgentTask.agent_id == agent_id)
.all()
)
1 change: 1 addition & 0 deletions src/backend/database_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa
from backend.database_models.agent import *
from backend.database_models.agent_task import *
from backend.database_models.agent_tool_metadata import *
from backend.database_models.base import *
from backend.database_models.blacklist import *
Expand Down
43 changes: 43 additions & 0 deletions src/backend/database_models/agent_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from sqlalchemy import (
DateTime,
ForeignKey,
Integer,
LargeBinary,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base, MinimalBase


class SyncCeleryTaskMeta(MinimalBase):
__tablename__ = "sync_celery_taskmeta"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
task_id: Mapped[str] = mapped_column(String(155), unique=True)
status: Mapped[str] = mapped_column(String(50))
result: Mapped[bytes] = mapped_column(LargeBinary)
date_done: Mapped[DateTime] = mapped_column(DateTime)
traceback: Mapped[str] = mapped_column(Text)
name: Mapped[str] = mapped_column(String(155))
args: Mapped[bytes] = mapped_column(LargeBinary)
kwargs: Mapped[bytes] = mapped_column(LargeBinary)
worker: Mapped[str] = mapped_column(String(155))
retries: Mapped[int] = mapped_column(Integer)
queue: Mapped[str] = mapped_column(String(155))


class AgentTask(Base):
__tablename__ = "agent_tasks"

agent_id: Mapped[str] = mapped_column(
ForeignKey("agents.id", ondelete="CASCADE"), nullable=False
)

task_id: Mapped[str] = mapped_column(nullable=False)

__table_args__ = (
UniqueConstraint("agent_id", "task_id", name="unique_agent_task"),
)
4 changes: 4 additions & 0 deletions src/backend/database_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __new__(cls, *args, **kwargs):
return object.__new__(cls)


class MinimalBase(DeclarativeBase):
pass


class Base(DeclarativeBase):
id = mapped_column(
String,
Expand Down
23 changes: 21 additions & 2 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Optional
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import File as RequestFile
Expand All @@ -11,6 +11,7 @@
from backend.crud import agent as agent_crud
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.crud import snapshot as snapshot_crud
from backend.crud.agent_task import get_agent_tasks_by_agent_id
from backend.database_models.agent import Agent as AgentModel
from backend.database_models.agent_tool_metadata import (
AgentToolMetadata as AgentToolMetadataModel,
Expand All @@ -20,6 +21,7 @@
from backend.schemas.agent import (
Agent,
AgentPublic,
AgentTaskResponse,
AgentToolMetadata,
AgentToolMetadataPublic,
AgentVisibility,
Expand All @@ -40,6 +42,7 @@
agent_to_metrics_agent,
)
from backend.services.agent import (
parse_task,
raise_db_error,
validate_agent_exists,
validate_agent_tool_metadata_exists,
Expand Down Expand Up @@ -246,7 +249,6 @@ async def get_agent_by_id(
agent_schema = Agent.model_validate(agent)
ctx.with_agent(agent_schema)
ctx.with_metrics_agent(agent_to_metrics_agent(agent))

return agent


Expand Down Expand Up @@ -278,6 +280,23 @@ async def get_agent_deployments(
]


@router.get(
"/{agent_id}/tasks",
response_model=List[AgentTaskResponse],
dependencies=[
Depends(validate_user_header),
],
)
async def get_agent_tasks(
agent_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> List[AgentTaskResponse]:
raw_tasks = get_agent_tasks_by_agent_id(session, agent_id)
tasks = [parse_task(t) for t in raw_tasks]
return tasks


@router.put(
"/{agent_id}",
response_model=AgentPublic,
Expand Down
12 changes: 11 additions & 1 deletion src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from enum import StrEnum
from typing import Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -110,6 +110,16 @@ class ListAgentsResponse(BaseModel):
agents: list[Agent]


class AgentTaskResponse(BaseModel):
task_id: str
status: str
result: Optional[Dict[str, Any]] = None
date_done: str
exception_snippet: Optional[str] = None
name: str
retries: int


class UpdateAgentRequest(BaseModel):
name: Optional[str] = None
version: Optional[int] = None
Expand Down
27 changes: 27 additions & 0 deletions src/backend/services/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import pickle

from fastapi import HTTPException

from backend.crud import agent as agent_crud
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.database_models.agent import Agent, AgentToolMetadata
from backend.database_models.agent_task import SyncCeleryTaskMeta
from backend.database_models.database import DBSessionDep
from backend.schemas.agent import AgentTaskResponse

TASK_TRACE_PREVIEW_LIMIT = 200

def validate_agent_exists(session: DBSessionDep, agent_id: str, user_id: str) -> Agent:
agent = agent_crud.get_agent_by_id(session, agent_id, user_id)
Expand Down Expand Up @@ -42,3 +47,25 @@ def raise_db_error(e: Exception, type: str, name: str):
)

raise HTTPException(status_code=500, detail=str(e))


def parse_task(t: SyncCeleryTaskMeta) -> AgentTaskResponse:
result = None
exception_snippet = None
if t.status == "SUCCESS":
result = pickle.loads(t.result)
if t.status == "FAILURE":
trace_lines = t.traceback.split("\n")
if len(trace_lines) >= 2:
# first 200 characters of the exception
exception_snippet = trace_lines[-2][:TASK_TRACE_PREVIEW_LIMIT] + "...check logs for details"

return AgentTaskResponse(
task_id=t.task_id,
status=t.status,
name=t.name,
retries=t.retries,
result=result,
exception_snippet=exception_snippet,
date_done=str(t.date_done),
)
Loading

0 comments on commit d21802a

Please sign in to comment.