Commit
[FEAT]
Kye committed Apr 16, 2024
1 parent b8e4930 commit a48e259
Showing 3 changed files with 233 additions and 5 deletions.
160 changes: 160 additions & 0 deletions servers/swarm_agents/normal_deploy.py
@@ -0,0 +1,160 @@
from dotenv import load_dotenv
import os

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from sse_starlette.sse import EventSourceResponse

from swarms_cloud.schema.cog_vlm_schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatMessageResponse,
    ModelCard,
    ModelList,
    UsageInfo,
)

# from exa.structs.parallelize_models_gpus import prepare_model_for_ddp_inference

# Load environment variables from .env file
load_dotenv()

# Environment variables
MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "lmsys/vicuna-7b-v1.5")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: environment variables are strings, so parse the flag explicitly.
QUANT_ENABLED = os.environ.get("QUANT_ENABLED", "true").lower() in ("1", "true", "yes")

# Create a FastAPI app
app = FastAPI(
    title="Swarms Cloud API",
    description="A simple API server for Swarms Cloud",
    debug=True,
    version="1.0",
)


# Load the middleware to handle CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """
    An endpoint to list available models. It returns a list of model cards.
    This is useful for clients to query and understand what models are available for use.
    """
    model_card = ModelCard(
        id="cogvlm-chat-17b"
    )  # can be replaced by your model id, e.g. cogagent-chat-18b
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest,  # token: str = Depends(authenticate_user)
):
    try:
        if len(request.messages) < 1 or request.messages[-1].role == "assistant":
            raise HTTPException(status_code=400, detail="Invalid request")

        # print(f"Request: {request}")
        # Collect the generation parameters from the request.
        gen_params = dict(
            messages=request.messages,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens or 1024,
            echo=False,
            stream=request.stream,
        )

        if request.stream:
            # Placeholder: the streaming generator is not wired up yet.
            # generate = predict(request.model, gen_params)
            generate = None
            return EventSourceResponse(generate, media_type="text/event-stream")

        # Placeholder: the generation call is not wired up yet.
        # response = generate_cogvlm(model, tokenizer, gen_params)
        response = None

        usage = UsageInfo()

        # Build the assistant message from the model output.
        message = ChatMessageResponse(
            role="assistant",
            content=response["text"],
        )

        # Log the entry to supabase
        # entry = ModelAPILogEntry(
        #     user_id=fetch_api_key_info(token),
        #     model_id="41a2869c-5f8d-403f-83bb-1f06c56bad47",
        #     input_tokens=count_tokens(request.messages, tokenizer, request.model),
        #     output_tokens=count_tokens(response["text"], tokenizer, request.model),
        #     all_cost=calculate_pricing(
        #         texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
        #     ),
        #     input_cost=calculate_pricing(
        #         texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
        #     ),
        #     output_cost=calculate_pricing(
        #         texts=response["text"], tokenizer=tokenizer, rate_per_million=15.0
        #     )
        #     * 5,
        #     messages=request.messages,
        #     # temperature=request.temperature,
        #     top_p=request.top_p,
        #     # echo=request.echo,
        #     stream=request.stream,
        #     repetition_penalty=request.repetition_penalty,
        #     max_tokens=request.max_tokens,
        # )

        # log_to_supabase(entry=entry)
        # Build the response choice.
        logger.debug(f"==== message ====\n{message}")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
        )

        # task_usage = UsageInfo.model_validate(response["usage"])
        task_usage = UsageInfo.parse_obj(response["usage"])
        for usage_key, usage_value in task_usage.dict().items():
            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        out = ChatCompletionResponse(
            model=request.model,
            choices=[choice_data],
            object="chat.completion",
            usage=usage,
        )

        return out
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("SWARM_AGENT_API_PORT", 8000)),
        log_level="info",
        use_colors=True,
    )
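
For context, a minimal client sketch against this server — not part of the commit, and it assumes the stubbed-out generation path has been wired up, the server is listening on the default port 8000, and the response mirrors the OpenAI-style ChatCompletionResponse schema imported above:

import requests

# Hypothetical payload; field names mirror what the handler reads from ChatCompletionRequest.
payload = {
    "model": "cogvlm-chat-17b",  # matches the ModelCard id served by /v1/models
    "messages": [{"role": "user", "content": "Describe the Swarms Cloud API."}],
    "temperature": 0.8,
    "top_p": 0.9,
    "max_tokens": 256,
    "stream": False,
}

resp = requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, timeout=120
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])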
68 changes: 68 additions & 0 deletions servers/swarm_agents/omni_modal_agent.py
@@ -0,0 +1,68 @@
from swarms import Agent, Anthropic, tool

# Model
llm = Anthropic(
    temperature=0.1,
)


# Tools
@tool
def text_to_video(task: str):
    """
    Converts a given text task into an animated video.

    Args:
        task (str): The text task to be converted into a video.

    Returns:
        str: The path to the exported GIF file.
    """
    import torch
    from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
    from diffusers.utils import export_to_gif
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    device = "cuda"
    dtype = torch.float16

    step = 4  # Options: [1, 2, 4, 8]
    repo = "ByteDance/AnimateDiff-Lightning"
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    base = "emilianJR/epiCRealism"  # Choose your favorite base model.

    adapter = MotionAdapter().to(device, dtype)
    adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
    pipe = AnimateDiffPipeline.from_pretrained(
        base, motion_adapter=adapter, torch_dtype=dtype
    ).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )

    output = pipe(prompt=task, guidance_scale=1.0, num_inference_steps=step)
    out = export_to_gif(output.frames[0], "animation.gif")
    return out



# Agent
agent = Agent(
    agent_name="Devin",
    system_prompt=(
        "Autonomous agent that can interact with humans and other"
        " agents. Be Helpful and Kind. Use the tools provided to"
        " assist the user. Return all code in markdown format."
    ),
    llm=llm,
    max_loops="auto",
    autosave=True,
    dashboard=False,
    streaming_on=True,
    verbose=True,
    stopping_token="<DONE>",
    interactive=True,
    tools=[text_to_video],
    code_interpreter=True,
)

# Run the agent
out = agent("Create a video of a girl coding AI wearing a hijab")
print(out)
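
As a quick sanity check, the tool can also be exercised outside the agent loop — a sketch assuming a CUDA-capable GPU, the diffusers/huggingface_hub/safetensors dependencies installed, and that the @tool wrapper keeps the underlying function directly callable (if not, the body can be lifted into a plain function):

# Direct call to the tool, bypassing the agent entirely.
gif_path = text_to_video("A rocket launching over the ocean at sunrise")
print(f"Animation exported to: {gif_path}")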
10 changes: 5 additions & 5 deletions servers/vllm_llm/sky_serve.yaml
@@ -102,11 +102,11 @@ service:
 # Fields below describe each replica.
 resources:
   accelerators: {L4:8, A10g:8, A100:4, A100:8, A100-80GB:2, A100-80GB:4, A100-80GB:8}
-  cpus: 32+
-  memory: 512+
-  use_spot: True
-  disk_size: 512  # Ensure model checkpoints (~246GB) can fit.
-  disk_tier: best
+  # cpus: 32+
+  # memory: 512+
+  # use_spot: True
+  # disk_size: 512  # Ensure model checkpoints (~246GB) can fit.
+  # disk_tier: best
   ports: 8080  # Expose to internet traffic.

# workdir: ~/swarms-cloud/servers/cogvlm
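
For reference, this service file is deployed with SkyPilot (e.g. sky serve up sky_serve.yaml); with the resource fields commented out, only the accelerator set and the exposed port stay pinned, and the remaining sizing is left to SkyPilot's defaults.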
