Skip to content

Commit

Permalink
Update api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez authored Jul 26, 2024
1 parent d60b550 commit 5bfc0e3
Showing 1 changed file with 18 additions and 33 deletions.
51 changes: 18 additions & 33 deletions servers/agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
from pydantic import BaseModel
from swarms import Agent, Anthropic, GPT4o, GPT4VisionAPI, OpenAIChat
from swarms.utils.loguru_logger import logger

from swarms_cloud.schema.cog_vlm_schemas import (
ChatCompletionResponse,
UsageInfo,
)

from swarms_cloud.schema.cog_vlm_schemas import ChatCompletionResponse, UsageInfo

# Define the input model using Pydantic
class AgentInput(BaseModel):
Expand All @@ -35,16 +30,15 @@ class AgentInput(BaseModel):
context_length: int = 8192
task: str = None


# Define the input model using Pydantic
# Define the output model using Pydantic
class AgentOutput(BaseModel):
agent: AgentInput
completions: ChatCompletionResponse

# Define the available models
AVAILABLE_MODELS = ["OpenAIChat", "GPT4o", "GPT4VisionAPI", "Anthropic"]

async def count_tokens(
text: str,
):
def count_tokens(text: str):
try:
# Get the encoding for the specific model
encoding = tiktoken.get_encoding("gpt-4o")
Expand All @@ -59,8 +53,7 @@ async def count_tokens(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


async def model_router(model_name: str):
def model_router(model_name: str):
"""
Function to switch to the specified model.
Expand Down Expand Up @@ -89,11 +82,10 @@ async def model_router(model_name: str):
llm = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"))
else:
# Invalid model name
pass
raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")

return llm


# Create a FastAPI app
app = FastAPI(debug=True)

Expand All @@ -106,20 +98,15 @@ async def model_router(model_name: str):
allow_headers=["*"],
)

@app.get("/v1/models", response_model=List[str])
async def list_models():
"""
An endpoint to list available models. It returns a list of model names.
This is useful for clients to query and understand what models are available for use.
"""
return AVAILABLE_MODELS

# @app.get("/v1/models", response_model=ModelList)
# async def list_models():
# """
# An endpoint to list available models. It returns a list of model cards.
# This is useful for clients to query and understand what models are available for use.
# """
# model_card = ModelCard(
# id="cogvlm-chat-17b"
# ) # can be replaced by your model id like cogagent-chat-18b
# return ModelList(data=[model_card])


@app.post("v1/agent/completions", response_model=AgentOutput)
@app.post("/v1/agent/completions", response_model=AgentOutput)
async def agent_completions(agent_input: AgentInput):
try:
logger.info(f"Received request: {agent_input}")
Expand Down Expand Up @@ -149,10 +136,9 @@ async def agent_completions(agent_input: AgentInput):
completions = await agent.run(agent_input.task)

logger.info(f"Completions: {completions}")
all_input_tokens, output_tokens = await asyncio.gather(
count_tokens(agent.short_memory.return_history_as_string()),
count_tokens(completions),
)
input_history = agent.short_memory.return_history_as_string()
all_input_tokens = count_tokens(input_history)
output_tokens = count_tokens(completions)

logger.info(f"Token counts: {all_input_tokens}, {output_tokens}")

Expand Down Expand Up @@ -183,7 +169,6 @@ async def agent_completions(agent_input: AgentInput):
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


if __name__ == "__main__":
import uvicorn

Expand Down

0 comments on commit 5bfc0e3

Please sign in to comment.