diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b94860d37..8bbd7a241 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -91,6 +91,7 @@ jobs:
ENV: DEV
PLAIN_OUTPUT: True
REDIS_URL: "localhost:6379"
+ IS_TESTING: True
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml
index 94044916b..66a7086d1 100644
--- a/docker-compose-dev.yaml
+++ b/docker-compose-dev.yaml
@@ -73,4 +73,4 @@ networks:
driver: bridge
volumes:
superagi_postgres_data:
- redis_data:
\ No newline at end of file
+ redis_data:
diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js
index 1fb1f5d48..3f7f62b30 100644
--- a/gui/pages/Content/Agents/AgentCreate.js
+++ b/gui/pages/Content/Agents/AgentCreate.js
@@ -534,7 +534,7 @@ export default function AgentCreate({
setEditButtonClicked(true);
agentData.agent_id = editAgentId;
const name = agentData.name
- const adjustedDate = new Date((new Date()).getTime() + 6*24*60*60*1000 - 1*60*1000);
+ const adjustedDate = new Date((new Date()).getTime());
const formattedDate = `${adjustedDate.getDate()} ${['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December'][adjustedDate.getMonth()]} ${adjustedDate.getFullYear()} ${adjustedDate.getHours().toString().padStart(2, '0')}:${adjustedDate.getMinutes().toString().padStart(2, '0')}`;
agentData.name = "Run " + formattedDate
addAgentRun(agentData)
diff --git a/gui/pages/Content/Agents/Agents.module.css b/gui/pages/Content/Agents/Agents.module.css
index 62d782ef9..c6553a8f9 100644
--- a/gui/pages/Content/Agents/Agents.module.css
+++ b/gui/pages/Content/Agents/Agents.module.css
@@ -429,4 +429,35 @@
color: #888888 !important;
text-decoration: line-through;
pointerEvents: none !important;
+}
+
+.modal_buttons{
+ display: flex;
+ justify-content: flex-end;
+ margin-top: 20px
+}
+
+.modal_info_class{
+ margin-left: -5px;
+ margin-right: 5px;
+}
+
+.table_contents{
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ margin-top: 40px;
+ width: 100%
+}
+
+.create_settings_button{
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ margin-top: 10px
+}
+
+.button_margin{
+ margin-top: -10px;
}
\ No newline at end of file
diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css
index 545146d52..4bd162ade 100644
--- a/gui/pages/Content/Marketplace/Market.module.css
+++ b/gui/pages/Content/Marketplace/Market.module.css
@@ -515,3 +515,17 @@
overflow-y:scroll;
overflow-x:hidden;
}
+
+.settings_tab_button_clicked{
+ background: #454254;
+ padding-right: 15px
+}
+
+.settings_tab_button{
+ background: transparent;
+ padding-right: 15px
+}
+
+.settings_tab_img{
+ margin-top: -1px;
+}
diff --git a/gui/pages/Dashboard/Settings/ApiKeys.js b/gui/pages/Dashboard/Settings/ApiKeys.js
new file mode 100644
index 000000000..61654b4ca
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/ApiKeys.js
@@ -0,0 +1,286 @@
+import React, {useState, useEffect, useRef} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {
+ createApiKey, deleteApiKey,
+ editApiKey, getApiKeys,
+} from "@/pages/api/DashboardService";
+import {EventBus} from "@/utils/eventBus";
+import {createInternalId, loadingTextEffect, preventDefault, removeTab, returnToolkitIcon} from "@/utils/utils";
+import Image from "next/image";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
+import styles1 from "@/pages/Content/Knowledge/Knowledge.module.css";
+
+export default function ApiKeys() {
+ const [apiKeys, setApiKeys] = useState([]);
+ const [keyName, setKeyName] = useState('');
+ const [editKey, setEditKey] = useState('');
+ const apiKeyRef = useRef(null);
+ const editKeyRef = useRef(null);
+ const [editKeyId, setEditKeyId] = useState(-1);
+ const [deleteKey, setDeleteKey] = useState('')
+ const [isLoading, setIsLoading] = useState(true)
+ const [activeDropdown, setActiveDropdown] = useState(null);
+ const [editModal, setEditModal] = useState(false);
+ const [deleteKeyId, setDeleteKeyId] = useState(-1);
+ const [deleteModal, setDeleteModal] = useState(false);
+ const [createModal, setCreateModal] = useState(false);
+ const [displayModal, setDisplayModal] = useState(false);
+ const [apiKeyGenerated, setApiKeyGenerated] = useState('');
+ const [loadingText, setLoadingText] = useState("Loading Api Keys");
+
+
+
+ useEffect(() => {
+ loadingTextEffect('Loading Api Keys', setLoadingText, 500);
+ fetchApiKeys()
+ }, []);
+
+
+ const handleModelApiKey = (event) => {
+ setKeyName(event.target.value);
+ };
+
+ const handleEditApiKey = (event) => {
+ setEditKey(event.target.value);
+ };
+
+ const createApikey = () => {
+ if(!keyName){
+ toast.error("Enter key name", {autoClose: 1800});
+ return;
+ }
+ createApiKey({name : keyName})
+ .then((response) => {
+ setApiKeyGenerated(response.data.api_key)
+ toast.success("Api Key Generated", {autoClose: 1800});
+ setCreateModal(false);
+ setDisplayModal(true);
+ fetchApiKeys();
+ })
+ .catch((error) => {
+ console.error('Error creating api key', error);
+ });
+ }
+ const handleCopyClick = async () => {
+ if (apiKeyRef.current) {
+ try {
+ await navigator.clipboard.writeText(apiKeyRef.current.value);
+ toast.success("Key Copied", {autoClose: 1800});
+ } catch (err) {
+ toast.error('Failed to Copy', {autoClose: 1800});
+ }
+ }
+ };
+
+ const fetchApiKeys = () => {
+ getApiKeys()
+ .then((response) => {
+ const formattedData = response.data.map(item => {
+ return {
+ ...item,
+ created_at: `${new Date(item.created_at).getDate()}-${["JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"][new Date(item.created_at).getMonth()]}-${new Date(item.created_at).getFullYear()}`
+ };
+ });
+ setApiKeys(formattedData)
+ setIsLoading(false)
+ })
+ .catch((error) => {
+ console.error('Error fetching Api Keys', error);
+ });
+ }
+
+ const handleEditClick = () => {
+ if(editKeyRef.current.value.length <1){
+ toast.error("Enter valid key name", {autoClose: 1800});
+ return;
+ }
+ editApiKey({id: editKeyId,name : editKey})
+ .then((response) => {
+ toast.success("Api Key Edited", {autoClose: 1800});
+ fetchApiKeys();
+ setEditModal(false);
+ setEditKey('')
+ setEditKeyId(-1)
+ })
+ .catch((error) => {
+ console.error('Error editing api key', error);
+ });
+ }
+
+ const handleDeleteClick = () => {
+ deleteApiKey(deleteKeyId)
+ .then((response) => {
+ toast.success("Api Key Deleted", {autoClose: 1800});
+ fetchApiKeys();
+ setDeleteModal(false);
+ setDeleteKeyId(-1)
+ setDeleteKey('')
+ })
+ .catch((error) => {
+ toast.error("Error deleting api key", {autoClose: 1800});
+ console.error('Error deleting api key', error);
+ });
+ }
+
+ return (<>
+
setDropdown(true)}
+
setDropdown(true)}
onMouseLeave={() => setDropdown(false)}>
- setDropdown(false)}>{userName}
+ {userName && setDropdown(false)}>{userName} }
Logout
}
diff --git a/gui/pages/_app.css b/gui/pages/_app.css
index c78b07085..7be317c86 100644
--- a/gui/pages/_app.css
+++ b/gui/pages/_app.css
@@ -997,13 +997,16 @@ p {
.r_0{right: 0}
.w_120p{width: 120px}
+.w_4{width: 4%}
.w_6{width: 6%}
.w_10{width: 10%}
.w_12{width: 12%}
+.w_18{width: 18%}
.w_20{width: 20%}
.w_22{width: 22%}
.w_35{width: 35%}
.w_56{width: 56%}
+.w_60{width: 60%}
.w_100{width: 100%}
.w_inherit{width: inherit}
.w_fit_content{width:fit-content}
@@ -1052,6 +1055,7 @@ p {
.gap_16{gap:16px;}
.gap_20{gap:20px;}
+.border_top_none{border-top: none;}
.border_radius_8{border-radius: 8px;}
.border_radius_25{border-radius: 25px;}
@@ -1072,12 +1076,15 @@ p {
.padding_0_8{padding: 0px 8px;}
.padding_0_15{padding: 0px 15px;}
+.flex_1{flex: 1}
.flex_wrap{flex-wrap: wrap;}
.mix_blend_mode{mix-blend-mode: exclusion;}
.ff_sourceCode{font-family: 'Source Code Pro'}
+.rotate_90{transform: rotate(90deg)}
+
/*------------------------------- My ROWS AND COLUMNS -------------------------------*/
.my_rows {
display: flex;
@@ -1673,3 +1680,27 @@ tr{
.history_box_selected{
background: #474255;
}
+
+.loading_container{
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ height: 50vh
+}
+
+.loading_text{
+ font-size: 16px;
+ font-family: 'Source Code Pro';
+}
+
+.table_container{
+ background: #272335;
+ border-radius: 8px;
+ margin-top:15px
+}
+
+.top_bar_profile_dropdown{
+ display: flex;
+ flex-direction: row;
+ justify-content: center;
+}
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index ed9b0a1ef..625c6c2af 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -303,3 +303,20 @@ export const fetchKnowledgeTemplateOverview = (knowledgeName) => {
export const installKnowledgeTemplate = (knowledgeName, indexId) => {
return api.get(`/knowledges/install/${knowledgeName}/index/${indexId}`);
};
+
+export const createApiKey = (apiName) => {
+ return api.post(`/api-keys`, apiName);
+};
+
+export const getApiKeys = () => {
+ return api.get(`/api-keys`);
+};
+
+export const editApiKey = (apiDetails) => {
+ return api.put(`/api-keys`, apiDetails);
+};
+
+export const deleteApiKey = (apiId) => {
+ return api.delete(`/api-keys/${apiId}`);
+};
+
diff --git a/gui/public/images/copy_icon.svg b/gui/public/images/copy_icon.svg
new file mode 100644
index 000000000..46644e0dc
--- /dev/null
+++ b/gui/public/images/copy_icon.svg
@@ -0,0 +1,3 @@
+
+
+
diff --git a/gui/public/images/key_white.svg b/gui/public/images/key_white.svg
new file mode 100644
index 000000000..9b35b92e0
--- /dev/null
+++ b/gui/public/images/key_white.svg
@@ -0,0 +1,3 @@
+
+
+
diff --git a/gui/public/images/weaviate.svg b/gui/public/images/weaviate.svg
index 1f1ee9788..9b9f5260c 100644
--- a/gui/public/images/weaviate.svg
+++ b/gui/public/images/weaviate.svg
@@ -6,4 +6,4 @@
-
\ No newline at end of file
+
diff --git a/main.py b/main.py
index f6227a179..4c698eb9a 100644
--- a/main.py
+++ b/main.py
@@ -38,6 +38,9 @@
from superagi.controllers.vector_dbs import router as vector_dbs_router
from superagi.controllers.vector_db_indices import router as vector_db_indices_router
from superagi.controllers.marketplace_stats import router as marketplace_stats_router
+from superagi.controllers.api_key import router as api_key_router
+from superagi.controllers.api.agent import router as api_agent_router
+from superagi.controllers.webhook import router as web_hook_router
from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
@@ -50,7 +53,6 @@
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
-
app = FastAPI()
database_url = get_config('POSTGRES_URL')
@@ -113,7 +115,9 @@
app.include_router(vector_dbs_router, prefix="/vector_dbs")
app.include_router(vector_db_indices_router, prefix="/vector_db_indices")
app.include_router(marketplace_stats_router, prefix="/marketplace")
-
+app.include_router(api_key_router, prefix="/api-keys")
+app.include_router(api_agent_router,prefix="/v1/agent")
+app.include_router(web_hook_router,prefix="/webhook")
# in production you can use Settings management
# from pydantic to get secret key from .env
@@ -370,3 +374,4 @@ def github_client_id():
# # __________________TO RUN____________________________
# # uvicorn main:app --host 0.0.0.0 --port 8001 --reload
+
diff --git a/migrations/versions/446884dcae58_add_api_key_and_web_hook.py b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py
new file mode 100644
index 000000000..c4b353756
--- /dev/null
+++ b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py
@@ -0,0 +1,65 @@
+"""add api_key and web_hook
+
+Revision ID: 446884dcae58
+Revises: 71e3980d55f5
+Create Date: 2023-07-29 10:55:21.714245
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '446884dcae58'
+down_revision = '2fbd6472112c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('api_keys',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('org_id', sa.Integer(), nullable=True),
+ sa.Column('name', sa.String(), nullable=True),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.Column('is_expired',sa.Boolean(),nullable=True,default=False),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('webhooks',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('name', sa.String(), nullable=True),
+ sa.Column('org_id', sa.Integer(), nullable=True),
+ sa.Column('url', sa.String(), nullable=True),
+ sa.Column('headers', sa.JSON(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.Column('is_deleted',sa.Boolean(),nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('webhook_events',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('agent_id', sa.Integer(), nullable=True),
+ sa.Column('run_id', sa.Integer(), nullable=True),
+ sa.Column('event', sa.String(), nullable=True),
+ sa.Column('status', sa.String(), nullable=True),
+ sa.Column('errors', sa.Text(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+
+ #add index *********************
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ op.drop_table('webhooks')
+ op.drop_table('api_keys')
+ op.drop_table('webhook_events')
+
+ # ### end Alembic commands ###
diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py
index 743aae05b..1a485562c 100644
--- a/superagi/controllers/agent_execution.py
+++ b/superagi/controllers/agent_execution.py
@@ -108,8 +108,7 @@ def create_agent_execution(agent_execution: AgentExecutionIn,
agent_execution_configs[agent_config.key] = []
elif agent_config.key == "constraints":
if agent_config.value:
- constraints = [item.strip('"') for item in agent_config.value.strip('{}').split(',')]
- agent_execution_configs[agent_config.key] = constraints
+ agent_execution_configs[agent_config.key] = agent_config.value
else:
agent_execution_configs[agent_config.key] = []
else:
diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py
new file mode 100644
index 000000000..a71e60a8f
--- /dev/null
+++ b/superagi/controllers/api/agent.py
@@ -0,0 +1,326 @@
+from fastapi import APIRouter
+from fastapi import HTTPException, Depends ,Security
+
+from fastapi_sqlalchemy import db
+from pydantic import BaseModel
+
+from superagi.worker import execute_agent
+from superagi.helper.auth import validate_api_key,get_organisation_from_api_key
+from superagi.models.agent import Agent
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+from superagi.models.agent_config import AgentConfiguration
+from superagi.models.agent_schedule import AgentSchedule
+from superagi.models.project import Project
+from superagi.models.workflows.agent_workflow import AgentWorkflow
+from superagi.models.agent_execution import AgentExecution
+from superagi.models.organisation import Organisation
+from superagi.models.resource import Resource
+from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput
+from superagi.models.workflows.iteration_workflow import IterationWorkflow
+from superagi.helper.s3_helper import S3Helper
+from datetime import datetime
+from typing import Optional,List
+from superagi.models.toolkit import Toolkit
+from superagi.apm.event_handler import EventHandler
+from superagi.config.config import get_config
+router = APIRouter()
+
+class AgentExecutionIn(BaseModel):
+ name: Optional[str]
+ goal: Optional[List[str]]
+ instruction: Optional[List[str]]
+
+ class Config:
+ orm_mode = True
+
+class RunFilterConfigIn(BaseModel):
+ run_ids:Optional[List[int]]
+ run_status_filter:Optional[str]
+
+ class Config:
+ orm_mode = True
+
+class ExecutionStateChangeConfigIn(BaseModel):
+ run_ids:Optional[List[int]]
+
+ class Config:
+ orm_mode = True
+
+class RunIDConfig(BaseModel):
+ run_ids:List[int]
+
+ class Config:
+ orm_mode = True
+
+@router.post("", status_code=200)
+def create_agent_with_config(agent_with_config: AgentConfigExtInput,
+ api_key: str = Security(validate_api_key), organisation:Organisation = Depends(get_organisation_from_api_key)):
+ project=Project.find_by_org_id(db.session, organisation.id)
+ try:
+ tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
+ except Exception as e:
+ raise HTTPException(status_code=404, detail=str(e))
+
+ agent_with_config.tools=tools_arr
+ agent_with_config.project_id=project.id
+ agent_with_config.exit="No exit criterion"
+ agent_with_config.permission_type="God Mode"
+ agent_with_config.LTM_DB=None
+ db_agent = Agent.create_agent_with_config(db, agent_with_config)
+
+ if agent_with_config.schedule is not None:
+ agent_schedule = AgentSchedule.save_schedule_from_config(db.session, db_agent, agent_with_config.schedule)
+ if agent_schedule is None:
+ raise HTTPException(status_code=500, detail="Failed to schedule agent")
+ EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
+ 'model': agent_with_config.model}, db_agent.id,
+ organisation.id if organisation else 0)
+ db.session.commit()
+ return {
+ "agent_id": db_agent.id
+ }
+
+ start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
+ iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
+ start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
+ # Creating an execution with RUNNING status
+ execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
+ name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id)
+ agent_execution_configs = {
+ "goal": agent_with_config.goal,
+ "instruction": agent_with_config.instruction
+ }
+ db.session.add(execution)
+ db.session.commit()
+ db.session.flush()
+ AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
+ agent_execution_configs=agent_execution_configs)
+
+ organisation = db_agent.get_agent_organisation(db.session)
+ EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
+ 'model': agent_with_config.model}, db_agent.id,
+ organisation.id if organisation else 0)
+ # execute_agent.delay(execution.id, datetime.now())
+ db.session.commit()
+ return {
+ "agent_id": db_agent.id
+ }
+
+@router.post("/{agent_id}/run",status_code=200)
+def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
+ agent=Agent.get_agent_from_id(db.session,agent_id)
+ if not agent:
+ raise HTTPException(status_code=404, detail="Agent not found")
+ project=Project.find_by_id(db.session, agent.project_id)
+ if project.organisation_id!=organisation.id:
+ raise HTTPException(status_code=404, detail="Agent not found")
+ db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id)
+ if db_schedule is not None:
+ raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run")
+ start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
+ db_agent_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED")
+
+ if db_agent_execution is None:
+ db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
+ agent_id=agent_id, name=agent_execution.name, num_of_calls=0,
+ num_of_tokens=0,
+ current_step_id=start_step_id)
+ db.session.add(db_agent_execution)
+ else:
+ db_agent_execution.status = "RUNNING"
+
+ db.session.commit()
+ db.session.flush()
+
+ agent_execution_configs = {}
+ if agent_execution.goal is not None:
+ agent_execution_configs = {
+ "goal": agent_execution.goal,
+ }
+
+ if agent_execution.instruction is not None:
+ agent_execution_configs["instructions"] = agent_execution.instruction,
+
+ if agent_execution_configs != {}:
+ AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
+ agent_execution_configs=agent_execution_configs)
+ EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name},
+ agent_id, organisation.id if organisation else 0)
+
+ if db_agent_execution.status == "RUNNING":
+ execute_agent.delay(db_agent_execution.id, datetime.now())
+ return {
+ "run_id":db_agent_execution.id
+ }
+
+@router.put("/{agent_id}",status_code=200)
+def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api_key: str = Security(validate_api_key),
+ organisation:Organisation = Depends(get_organisation_from_api_key)):
+
+ db_agent= Agent.get_active_agent_by_id(db.session, agent_id)
+ if not db_agent:
+ raise HTTPException(status_code=404, detail="agent not found")
+
+ project=Project.find_by_id(db.session, db_agent.project_id)
+ if project is None:
+ raise HTTPException(status_code=404, detail="Project not found")
+
+ if project.organisation_id!=organisation.id:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ db_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "RUNNING")
+ if db_execution is not None:
+ raise HTTPException(status_code=409, detail="Agent is already running,please pause and then update")
+
+ db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id)
+ if db_schedule is not None:
+ raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update")
+
+ try:
+ tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
+ except Exception as e:
+ raise HTTPException(status_code=404,detail=str(e))
+
+ if agent_with_config.schedule is not None:
+ raise HTTPException(status_code=400,detail="Cannot schedule an existing agent")
+ agent_with_config.tools=tools_arr
+ agent_with_config.project_id=project.id
+ agent_with_config.exit="No exit criterion"
+ agent_with_config.permission_type="God Mode"
+ agent_with_config.LTM_DB=None
+
+ for key,value in agent_with_config.dict().items():
+ if hasattr(db_agent,key) and value is not None:
+ setattr(db_agent,key,value)
+ db.session.commit()
+ db.session.flush()
+
+ start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
+ iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
+ start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
+ execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
+ name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id)
+ agent_execution_configs = {
+ "goal": agent_with_config.goal,
+ "instruction": agent_with_config.instruction,
+ "tools":agent_with_config.tools,
+ "constraints": agent_with_config.constraints,
+ "iteration_interval": agent_with_config.iteration_interval,
+ "model": agent_with_config.model,
+ "max_iterations": agent_with_config.max_iterations,
+ "agent_workflow": agent_with_config.agent_workflow,
+ }
+ agent_configurations = [
+ AgentConfiguration(agent_id=db_agent.id, key=key, value=str(value))
+ for key, value in agent_execution_configs.items()
+ ]
+ db.session.add_all(agent_configurations)
+ db.session.add(execution)
+ db.session.commit()
+ db.session.flush()
+ AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
+ agent_execution_configs=agent_execution_configs)
+ db.session.commit()
+
+ return {
+ "agent_id":db_agent.id
+ }
+
+
+@router.post("/{agent_id}/run-status")
+def get_agent_runs(agent_id:int,filter_config:RunFilterConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
+ agent= Agent.get_active_agent_by_id(db.session, agent_id)
+ if not agent:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ project=Project.find_by_id(db.session, agent.project_id)
+ if project.organisation_id!=organisation.id:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ db_execution_arr=[]
+ if filter_config.run_status_filter is not None:
+ filter_config.run_status_filter=filter_config.run_status_filter.upper()
+
+ db_execution_arr=AgentExecution.get_all_executions_by_filter_config(db.session, agent.id, filter_config)
+
+ response_arr=[]
+ for ind_execution in db_execution_arr:
+ response_arr.append({"run_id":ind_execution.id, "status":ind_execution.status})
+
+ return response_arr
+
+
+@router.post("/{agent_id}/pause",status_code=200)
+def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
+ agent= Agent.get_active_agent_by_id(db.session, agent_id)
+ if not agent:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ project=Project.find_by_id(db.session, agent.project_id)
+ if project.organisation_id!=organisation.id:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ #Checking if the run_ids whose output files are requested belong to the organisation
+ if execution_state_change_input.run_ids is not None:
+ try:
+ AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id)
+ except Exception as e:
+ raise HTTPException(status_code=404, detail="One or more run_ids not found")
+
+ db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "RUNNING")
+ for ind_execution in db_execution_arr:
+ ind_execution.status="PAUSED"
+ db.session.commit()
+ db.session.flush()
+ return {
+ "result":"success"
+ }
+
+@router.post("/{agent_id}/resume",status_code=200)
+def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
+ agent= Agent.get_active_agent_by_id(db.session, agent_id)
+ if not agent:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ project=Project.find_by_id(db.session, agent.project_id)
+ if project.organisation_id!=organisation.id:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ if execution_state_change_input.run_ids is not None:
+ try:
+ AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id)
+ except Exception as e:
+ raise HTTPException(status_code=404, detail="One or more run_ids not found")
+
+ db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "PAUSED")
+ for ind_execution in db_execution_arr:
+ ind_execution.status="RUNNING"
+
+ db.session.commit()
+ db.session.flush()
+ return {
+ "result":"success"
+ }
+
+@router.post("/resources/output",status_code=201)
+def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
+ if get_config('STORAGE_TYPE') != "S3":
+ raise HTTPException(status_code=400,detail="This endpoint only works when S3 is configured")
+ run_ids_arr=run_id_config.run_ids
+ if len(run_ids_arr)==0:
+ raise HTTPException(status_code=404,
+ detail=f"No execution_id found")
+ #Checking if the run_ids whose output files are requested belong to the organisation
+ try:
+ AgentExecution.validate_run_ids(db.session,run_ids_arr,organisation.id)
+ except Exception as e:
+ raise HTTPException(status_code=404, detail="One or more run_ids not found")
+
+ db_resources_arr=Resource.find_by_run_ids(db.session, run_ids_arr)
+
+ try:
+ response_obj=S3Helper().get_download_url_of_resources(db_resources_arr)
+ except:
+ raise HTTPException(status_code=401,detail="Invalid S3 credentials")
+ return response_obj
+
diff --git a/superagi/controllers/api_key.py b/superagi/controllers/api_key.py
new file mode 100644
index 000000000..57e5c739b
--- /dev/null
+++ b/superagi/controllers/api_key.py
@@ -0,0 +1,55 @@
+import json
+import uuid
+from fastapi import APIRouter, Body
+from fastapi import HTTPException, Depends
+from fastapi_jwt_auth import AuthJWT
+from fastapi_sqlalchemy import db
+from pydantic import BaseModel
+from superagi.helper.auth import get_user_organisation
+from superagi.helper.auth import check_auth
+from superagi.models.api_key import ApiKey
+from typing import Optional, Annotated
+router = APIRouter()
+
+class ApiKeyIn(BaseModel):
+ id:int
+ name: str
+ class Config:
+ orm_mode = True
+
+class ApiKeyDeleteIn(BaseModel):
+ id:int
+ class Config:
+ orm_mode = True
+
+@router.post("")
+def create_api_key(name: Annotated[str,Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)):
+ api_key=str(uuid.uuid4())
+ obj=ApiKey(key=api_key,name=name,org_id=organisation.id)
+ db.session.add(obj)
+ db.session.commit()
+ db.session.flush()
+ return {"api_key": api_key}
+
+@router.get("")
+def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)):
+ api_keys=ApiKey.get_by_org_id(db.session, organisation.id)
+ return api_keys
+
+@router.delete("/{api_key_id}")
+def delete_api_key(api_key_id:int, Authorize: AuthJWT = Depends(check_auth)):
+ api_key=ApiKey.get_by_id(db.session, api_key_id)
+ if api_key is None:
+ raise HTTPException(status_code=404, detail="API key not found")
+ ApiKey.delete_by_id(db.session, api_key_id)
+ return {"success": True}
+
+@router.put("")
+def edit_api_key(api_key_in:ApiKeyIn,Authorize: AuthJWT = Depends(check_auth)):
+ api_key=ApiKey.get_by_id(db.session, api_key_in.id)
+ if api_key is None:
+ raise HTTPException(status_code=404, detail="API key not found")
+ ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name)
+ return {"success": True}
+
+
diff --git a/superagi/controllers/types/agent_with_config.py b/superagi/controllers/types/agent_with_config.py
index f3995b5df..5ce81d211 100644
--- a/superagi/controllers/types/agent_with_config.py
+++ b/superagi/controllers/types/agent_with_config.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel
from typing import List, Optional
-
+from superagi.controllers.types.agent_schedule import AgentScheduleInput
class AgentConfigInput(BaseModel):
name: str
@@ -20,3 +20,45 @@ class AgentConfigInput(BaseModel):
max_iterations: int
user_timezone: Optional[str]
knowledge: Optional[int]
+
+
+
+class AgentConfigExtInput(BaseModel):
+ name: str
+ description: str
+ project_id: Optional[int]
+ goal: List[str]
+ instruction: List[str]
+ agent_workflow: str
+ constraints: List[str]
+ tools: List[dict]
+ LTM_DB:Optional[str]
+ exit: Optional[str]
+ permission_type: Optional[str]
+ iteration_interval: int
+ model: str
+ schedule: Optional[AgentScheduleInput]
+ max_iterations: int
+ user_timezone: Optional[str]
+ knowledge: Optional[int]
+
+class AgentConfigUpdateExtInput(BaseModel):
+ name: Optional[str]
+ description: Optional[str]
+ project_id: Optional[int]
+ goal: Optional[List[str]]
+ instruction: Optional[List[str]]
+ agent_workflow: Optional[str]
+ constraints: Optional[List[str]]
+ tools: Optional[List[dict]]
+ LTM_DB:Optional[str]
+ exit: Optional[str]
+ permission_type: Optional[str]
+ iteration_interval: Optional[int]
+ model: Optional[str]
+ schedule: Optional[AgentScheduleInput]
+ max_iterations: Optional[int]
+ user_timezone: Optional[str]
+ knowledge: Optional[int]
+
+
diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py
new file mode 100644
index 000000000..0a55bd216
--- /dev/null
+++ b/superagi/controllers/webhook.py
@@ -0,0 +1,60 @@
+from datetime import datetime
+
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi_jwt_auth import AuthJWT
+from fastapi_sqlalchemy import db
+from pydantic import BaseModel
+
+# from superagi.types.db import AgentOut, AgentIn
+from superagi.helper.auth import check_auth, get_user_organisation
+from superagi.models.webhooks import Webhooks
+
+router = APIRouter()
+
+
+class WebHookIn(BaseModel):
+ name: str
+ url: str
+ headers: dict
+
+ class Config:
+ orm_mode = True
+
+
+class WebHookOut(BaseModel):
+ id: int
+ org_id: int
+ name: str
+ url: str
+ headers: dict
+ is_deleted: bool
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+# CRUD Operations
+@router.post("/add", response_model=WebHookOut, status_code=201)
+def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth),
+ organisation=Depends(get_user_organisation)):
+ """
+ Creates a new webhook
+
+ Args:
+
+ Returns:
+ Agent: An object of Agent representing the created Agent.
+
+ Raises:
+ HTTPException (Status Code=404): If the associated project is not found.
+ """
+ db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id,
+ is_deleted=False)
+ db.session.add(db_webhook)
+ db.session.commit()
+ db.session.flush()
+
+ return db_webhook
diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py
index cc02d5643..f916a3ae5 100644
--- a/superagi/helper/auth.py
+++ b/superagi/helper/auth.py
@@ -1,10 +1,14 @@
-from fastapi import Depends, HTTPException
+from fastapi import Depends, HTTPException, Header, Security, status
+from fastapi.security import APIKeyHeader
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-
+from fastapi.security.api_key import APIKeyHeader
from superagi.config.config import get_config
from superagi.models.organisation import Organisation
from superagi.models.user import User
+from superagi.models.api_key import ApiKey
+from typing import Optional
+from sqlalchemy import or_
def check_auth(Authorize: AuthJWT = Depends()):
@@ -39,6 +43,7 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first()
return organisation
+
def get_current_user(Authorize: AuthJWT = Depends(check_auth)):
env = get_config("ENV", "DEV")
@@ -50,4 +55,32 @@ def get_current_user(Authorize: AuthJWT = Depends(check_auth)):
# Query the User table to find the user by their email
user = db.session.query(User).filter(User.email == email).first()
- return user
\ No newline at end of file
+ return user
+
+
+api_key_header = APIKeyHeader(name="X-API-Key")
+
+
+def validate_api_key(api_key: str = Security(api_key_header)) -> str:
+ query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key,
+ or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first()
+ if query_result is None:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing API Key",
+ )
+
+ return query_result.key
+
+
+def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation:
+ query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key,
+ or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first()
+ if query_result is None:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing API Key",
+ )
+
+ organisation = db.session.query(Organisation).filter(Organisation.id == query_result.org_id).first()
+ return organisation
\ No newline at end of file
diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py
index 2b2d0ba1f..2669b77ed 100644
--- a/superagi/helper/s3_helper.py
+++ b/superagi/helper/s3_helper.py
@@ -5,10 +5,9 @@
from superagi.config.config import get_config
from superagi.lib.logger import logger
+from urllib.parse import unquote
import json
-
-
class S3Helper:
def __init__(self, bucket_name = get_config("BUCKET_NAME")):
"""
@@ -113,4 +112,27 @@ def upload_file_content(self, content, file_path):
try:
self.s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=content)
except:
- raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")
\ No newline at end of file
+ raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")
+
+ def get_download_url_of_resources(self,db_resources_arr):
+ s3 = boto3.client(
+ 's3',
+ aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
+ aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
+ )
+ response_obj={}
+ for db_resource in db_resources_arr:
+ response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path)
+ content = response["Body"].read()
+ bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME")
+ file_name=db_resource.path.split('/')[-1]
+ file_name=''.join(char for char in file_name if char != "`")
+ object_key=f"public_resources/run_id{db_resource.agent_execution_id}/{file_name}"
+ s3.put_object(Bucket=bucket_name, Key=object_key, Body=content)
+ file_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}"
+ resource_execution_id=db_resource.agent_execution_id
+ if resource_execution_id in response_obj:
+ response_obj[resource_execution_id].append(file_url)
+ else:
+ response_obj[resource_execution_id]=[file_url]
+ return response_obj
\ No newline at end of file
diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py
new file mode 100644
index 000000000..cf5e988d2
--- /dev/null
+++ b/superagi/helper/webhook_manager.py
@@ -0,0 +1,37 @@
+from superagi.models.agent import Agent
+from superagi.models.agent_execution import AgentExecution
+from superagi.models.webhooks import Webhooks
+from superagi.models.webhook_events import WebhookEvents
+import requests
+import json
+from superagi.lib.logger import logger
+class WebHookManager:
+ def __init__(self,session):
+ self.session=session
+
+ def agent_status_change_callback(self, agent_execution_id, curr_status, old_status):
+ if curr_status=="CREATED" or agent_execution_id is None:
+ return
+ agent_id=AgentExecution.get_agent_execution_from_id(self.session,agent_execution_id).agent_id
+ agent=Agent.get_agent_from_id(self.session,agent_id)
+ org=agent.get_agent_organisation(self.session)
+ org_webhooks=self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all()
+
+ for webhook_obj in org_webhooks:
+ webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"}
+ error=None
+ request=None
+ status='sent'
+ try:
+ request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers)
+ except Exception as e:
+ logger.error(f"Exception occured in webhooks {e}")
+ error=str(e)
+ if request is not None and request.status_code not in [200,201] and error is None:
+ error=request.text
+ if error is not None:
+ status='Error'
+ webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error)
+ self.session.add(webhook_event)
+ self.session.commit()
+
diff --git a/superagi/models/agent.py b/superagi/models/agent.py
index 38a8796fc..49bd93f9d 100644
--- a/superagi/models/agent.py
+++ b/superagi/models/agent.py
@@ -4,8 +4,10 @@
import json
from sqlalchemy import Column, Integer, String, Boolean
+from sqlalchemy import or_
from superagi.lib.logger import logger
+from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_template import AgentTemplate
from superagi.models.agent_template_config import AgentTemplateConfig
# from superagi.models import AgentConfiguration
@@ -13,7 +15,7 @@
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.workflows.agent_workflow import AgentWorkflow
-from superagi.models.agent_config import AgentConfiguration
+
class Agent(DBBaseModel):
"""
@@ -35,8 +37,8 @@ class Agent(DBBaseModel):
project_id = Column(Integer)
description = Column(String)
agent_workflow_id = Column(Integer)
- is_deleted = Column(Boolean, default = False)
-
+ is_deleted = Column(Boolean, default=False)
+
def __repr__(self):
"""
Returns a string representation of the Agent object.
@@ -47,8 +49,8 @@ def __repr__(self):
"""
return f"Agent(id={self.id}, name='{self.name}', project_id={self.project_id}, " \
f"description='{self.description}', agent_workflow_id={self.agent_workflow_id}," \
- f"is_deleted='{self.is_deleted}')"
-
+ f"is_deleted='{self.is_deleted}')"
+
@classmethod
def fetch_configuration(cls, session, agent_id: int):
"""
@@ -105,7 +107,8 @@ def eval_agent_config(cls, key, value):
"""
- if key in ["name", "description", "exit", "model", "permission_type", "LTM_DB", "resource_summary", "knowledge"]:
+ if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB",
+ "resource_summary", "knowledge"]:
return value
elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]:
return int(value)
@@ -150,7 +153,6 @@ def create_agent_with_config(cls, db, agent_with_config):
# AgentWorkflow.name == "Fixed Task Queue").first()
# db_agent.agent_workflow_id = agent_workflow.id
-
db.session.commit()
# Create Agent Configuration
@@ -291,3 +293,9 @@ def find_org_by_agent_id(cls, session, agent_id: int):
agent = session.query(Agent).filter_by(id=agent_id).first()
project = session.query(Project).filter(Project.id == agent.project_id).first()
return session.query(Organisation).filter(Organisation.id == project.organisation_id).first()
+
+ @classmethod
+ def get_active_agent_by_id(cls, session, agent_id: int):
+ db_agent = session.query(Agent).filter(Agent.id == agent_id,
+ or_(Agent.is_deleted == False, Agent.is_deleted is None)).first()
+ return db_agent
diff --git a/superagi/models/agent_execution.py b/superagi/models/agent_execution.py
index f95afb85b..5cba5f509 100644
--- a/superagi/models/agent_execution.py
+++ b/superagi/models/agent_execution.py
@@ -164,4 +164,49 @@ def assign_next_step_id(cls, session, agent_execution_id: int, next_step_id: int
if next_step.action_type == "ITERATION_WORKFLOW":
trigger_step = IterationWorkflow.fetch_trigger_step_id(session, next_step.action_reference_id)
agent_execution.iteration_workflow_step_id = trigger_step.id
- session.commit()
\ No newline at end of file
+ session.commit()
+
+ @classmethod
+ def get_execution_by_agent_id_and_status(cls, session, agent_id: int, status_filter: str):
+ db_agent_execution = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == status_filter).first()
+ return db_agent_execution
+
+
+ @classmethod
+ def get_all_executions_by_status_and_agent_id(cls, session, agent_id, execution_state_change_input, current_status: str):
+ db_execution_arr = []
+ if execution_state_change_input.run_ids is not None:
+ db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status,AgentExecution.id.in_(execution_state_change_input.run_ids)).all()
+ else:
+ db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status).all()
+ return db_execution_arr
+
+ @classmethod
+ def get_all_executions_by_filter_config(cls, session, agent_id: int, filter_config):
+ db_execution_query = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id)
+ if filter_config.run_ids is not None:
+ db_execution_query = db_execution_query.filter(AgentExecution.id.in_(filter_config.run_ids))
+
+ if filter_config.run_status_filter is not None and filter_config.run_status_filter in ["CREATED", "RUNNING",
+ "PAUSED", "COMPLETED",
+ "TERMINATED"]:
+ db_execution_query = db_execution_query.filter(AgentExecution.status == filter_config.run_status_filter)
+
+ db_execution_arr = db_execution_query.all()
+ return db_execution_arr
+
+ @classmethod
+ def validate_run_ids(cls, session, run_ids: list, organisation_id: int):
+ from superagi.models.agent import Agent
+ from superagi.models.project import Project
+
+ run_ids=list(set(run_ids))
+ agent_ids=session.query(AgentExecution.agent_id).filter(AgentExecution.id.in_(run_ids)).distinct().all()
+ agent_ids = [id for (id,) in agent_ids]
+ project_ids=session.query(Agent.project_id).filter(Agent.id.in_(agent_ids)).distinct().all()
+ project_ids = [id for (id,) in project_ids]
+ org_ids=session.query(Project.organisation_id).filter(Project.id.in_(project_ids)).distinct().all()
+ org_ids = [id for (id,) in org_ids]
+
+ if len(org_ids) > 1 or org_ids[0] != organisation_id:
+ raise Exception(f"one or more run IDs not found")
diff --git a/superagi/models/agent_schedule.py b/superagi/models/agent_schedule.py
index 6415c375c..32e8dcd84 100644
--- a/superagi/models/agent_schedule.py
+++ b/superagi/models/agent_schedule.py
@@ -45,4 +45,27 @@ def __repr__(self):
f"expiry_date={self.expiry_date}, " \
f"expiry_runs={self.expiry_runs}), " \
f"current_runs={self.expiry_runs}), " \
- f"status={self.status}), "
\ No newline at end of file
+ f"status={self.status}), "
+
+ @classmethod
+ def save_schedule_from_config(cls, session, db_agent, schedule_config: AgentScheduleInput):
+ agent_schedule = AgentSchedule(
+ agent_id=db_agent.id,
+ start_time=schedule_config.start_time,
+ next_scheduled_time=schedule_config.start_time,
+ recurrence_interval=schedule_config.recurrence_interval,
+ expiry_date=schedule_config.expiry_date,
+ expiry_runs=schedule_config.expiry_runs,
+ current_runs=0,
+ status="SCHEDULED"
+ )
+
+ agent_schedule.agent_id = db_agent.id
+ session.add(agent_schedule)
+ session.commit()
+ return agent_schedule
+
+ @classmethod
+ def find_by_agent_id(cls, session, agent_id: int):
+ db_schedule=session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id).first()
+ return db_schedule
diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py
new file mode 100644
index 000000000..1cc3e310a
--- /dev/null
+++ b/superagi/models/api_key.py
@@ -0,0 +1,46 @@
+from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey
+from sqlalchemy.orm import relationship
+from superagi.models.base_model import DBBaseModel
+from superagi.models.agent_execution import AgentExecution
+from sqlalchemy import func, or_
+
+class ApiKey(DBBaseModel):
+ """
+
+ Attributes:
+
+
+ Methods:
+ """
+ __tablename__ = 'api_keys'
+
+ id = Column(Integer, primary_key=True)
+ org_id = Column(Integer)
+ name = Column(String)
+ key = Column(String)
+ is_expired= Column(Boolean)
+
+ @classmethod
+ def get_by_org_id(cls, session, org_id: int):
+ db_api_keys=session.query(ApiKey).filter(ApiKey.org_id==org_id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).all()
+ return db_api_keys
+
+ @classmethod
+ def get_by_id(cls, session, id: int):
+ db_api_key=session.query(ApiKey).filter(ApiKey.id==id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first()
+ return db_api_key
+
+ @classmethod
+ def delete_by_id(cls, session,id: int):
+ db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first()
+ db_api_key.is_expired = True
+ session.commit()
+ session.flush()
+
+ @classmethod
+ def update_api_key(cls, session, id: int, name: str):
+ db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first()
+ db_api_key.name = name
+ session.commit()
+ session.flush()
+
diff --git a/superagi/models/project.py b/superagi/models/project.py
index 1de1eb624..c798e1d1a 100644
--- a/superagi/models/project.py
+++ b/superagi/models/project.py
@@ -55,3 +55,13 @@ def find_or_create_default_project(cls, session, organisation_id):
else:
default_project = project
return default_project
+
+ @classmethod
+ def find_by_org_id(cls, session, org_id: int):
+ project = session.query(Project).filter(Project.organisation_id == org_id).first()
+ return project
+
+ @classmethod
+ def find_by_id(cls, session, project_id: int):
+ project = session.query(Project).filter(Project.id == project_id).first()
+ return project
\ No newline at end of file
diff --git a/superagi/models/resource.py b/superagi/models/resource.py
index 926123e47..78713b690 100644
--- a/superagi/models/resource.py
+++ b/superagi/models/resource.py
@@ -58,6 +58,11 @@ def validate_resource_type(storage_type):
if storage_type not in valid_types:
raise InvalidResourceType("Invalid resource type")
-
+
+ @classmethod
+ def find_by_run_ids(cls, session, run_ids: list):
+ db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids)).all()
+ return db_resources_arr
+
class InvalidResourceType(Exception):
"""Custom exception for invalid resource type"""
diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py
index 2c89a38ab..5a9c0a0e9 100644
--- a/superagi/models/toolkit.py
+++ b/superagi/models/toolkit.py
@@ -138,3 +138,26 @@ def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids):
if tool is not None:
agent_toolkit_tools.append(tool.id)
return agent_toolkit_tools
+
+ @classmethod
+ def get_tool_and_toolkit_arr(cls, session, agent_config_tools_arr: list):
+ from superagi.models.tool import Tool
+ toolkits_arr= set()
+ tools_arr= set()
+ for tool_obj in agent_config_tools_arr:
+ toolkit=session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip()).first()
+ if toolkit is None:
+ raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")
+ toolkits_arr.add(toolkit.id)
+ if tool_obj.get("tools"):
+ for tool_name_str in tool_obj["tools"]:
+ tool_db_obj=session.query(Tool).filter(Tool.name == tool_name_str.strip()).first()
+ if tool_db_obj is None:
+ raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")
+
+ tools_arr.add(tool_db_obj.id)
+ else:
+ tools=Tool.get_toolkit_tools(session, toolkit.id)
+ for tool_db_obj in tools:
+ tools_arr.add(tool_db_obj.id)
+ return list(tools_arr)
diff --git a/superagi/models/webhook_events.py b/superagi/models/webhook_events.py
new file mode 100644
index 000000000..af90ea492
--- /dev/null
+++ b/superagi/models/webhook_events.py
@@ -0,0 +1,25 @@
+from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey
+from sqlalchemy.orm import relationship
+from superagi.models.base_model import DBBaseModel
+from superagi.models.agent_execution import AgentExecution
+
+
+class WebhookEvents(DBBaseModel):
+ """
+
+ Attributes:
+
+
+ Methods:
+ """
+ __tablename__ = 'webhook_events'
+
+ id = Column(Integer, primary_key=True)
+ agent_id=Column(Integer)
+ run_id = Column(Integer)
+ event = Column(String)
+ status = Column(String)
+ errors= Column(Text)
+
+
+
diff --git a/superagi/models/webhooks.py b/superagi/models/webhooks.py
new file mode 100644
index 000000000..14d683472
--- /dev/null
+++ b/superagi/models/webhooks.py
@@ -0,0 +1,22 @@
+from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey,JSON
+from sqlalchemy.orm import relationship
+from sqlalchemy.dialects.postgresql import JSONB
+from superagi.models.base_model import DBBaseModel
+from superagi.models.agent_execution import AgentExecution
+
+class Webhooks(DBBaseModel):
+ """
+
+ Attributes:
+
+
+ Methods:
+ """
+ __tablename__ = 'webhooks'
+
+ id = Column(Integer, primary_key=True)
+ name=Column(String)
+ org_id = Column(Integer)
+ url = Column(String)
+ headers=Column(JSON)
+ is_deleted=Column(Boolean)
diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py
index 5d02c7f74..c67f8888b 100644
--- a/superagi/vector_store/vector_factory.py
+++ b/superagi/vector_store/vector_factory.py
@@ -94,11 +94,10 @@ def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embeddi
return qdrant.Qdrant(client, embedding_model, index_name)
except:
raise ValueError("Qdrant API key not found")
-
+
if vector_store == VectorStoreType.WEAVIATE:
try:
client = weaviate.create_weaviate_client(creds["url"], creds["api_key"])
return weaviate.Weaviate(client, embedding_model, index_name)
except:
raise ValueError("Weaviate API key not found")
-
\ No newline at end of file
diff --git a/superagi/worker.py b/superagi/worker.py
index 9e41c2fd8..b6e5ea231 100644
--- a/superagi/worker.py
+++ b/superagi/worker.py
@@ -15,6 +15,10 @@
from superagi.models.db import connect_db
from superagi.types.model_source_types import ModelSourceType
+from sqlalchemy import event
+from superagi.models.agent_execution import AgentExecution
+from superagi.helper.webhook_manager import WebHookManager
+
redis_url = get_config('REDIS_URL', 'super__redis:6379')
app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"])
@@ -32,9 +36,16 @@
}
app.conf.beat_schedule = beat_schedule
+@event.listens_for(AgentExecution.status, "set")
+def agent_status_change(target, val,old_val,initiator):
+ if get_config("IN_TESTING",False):
+ webhook_callback.delay(target.id,val,old_val)
+
+
@app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def initialize_schedule_agent_task():
"""Executing agent scheduling in the background."""
+
schedule_helper = AgentScheduleHelper()
schedule_helper.update_next_scheduled_time()
schedule_helper.run_scheduled_agents()
@@ -49,7 +60,7 @@ def execute_agent(agent_execution_id: int, time):
AgentExecutor().execute_next_step(agent_execution_id=agent_execution_id)
-@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5, serializer='pickle')
+@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle')
def summarize_resource(agent_id: int, resource_id: int):
"""Summarize a resource in background."""
from superagi.resource_manager.resource_summary import ResourceSummarizer
@@ -77,3 +88,11 @@ def summarize_resource(agent_id: int, resource_id: int):
resource_summarizer.add_to_vector_store_and_create_summary(resource_id=resource_id,
documents=documents)
session.close()
+
+@app.task(name="webhook_callback", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle')
+def webhook_callback(agent_execution_id,val,old_val):
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ with Session() as session:
+ WebHookManager(session).agent_status_change_callback(agent_execution_id, val, old_val)
+
diff --git a/tests/unit_tests/controllers/api/__init__.py b/tests/unit_tests/controllers/api/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/controllers/api/test_agent.py b/tests/unit_tests/controllers/api/test_agent.py
new file mode 100644
index 000000000..99fda13b2
--- /dev/null
+++ b/tests/unit_tests/controllers/api/test_agent.py
@@ -0,0 +1,220 @@
+import pytest
+from fastapi.testclient import TestClient
+from fastapi import HTTPException
+
+import superagi.config.config
+from unittest.mock import MagicMock, patch,Mock
+from main import app
+from unittest.mock import patch,create_autospec
+from sqlalchemy.orm import Session
+from superagi.controllers.api.agent import ExecutionStateChangeConfigIn,AgentConfigUpdateExtInput
+from superagi.models.agent import Agent
+from superagi.models.project import Project
+
+client = TestClient(app)
+
+@pytest.fixture
+def mock_api_key_get():
+ mock_api_key = "your_mock_api_key"
+ return mock_api_key
+@pytest.fixture
+def mock_execution_state_change_input():
+ return {
+
+ }
+@pytest.fixture
+def mock_run_id_config():
+ return {
+ "run_ids":[1,2]
+ }
+
+@pytest.fixture
+def mock_agent_execution():
+ return {
+
+ }
+@pytest.fixture
+def mock_run_id_config_empty():
+ return {
+ "run_ids":[]
+ }
+
+@pytest.fixture
+def mock_run_id_config_invalid():
+ return {
+ "run_ids":[12310]
+ }
+@pytest.fixture
+def mock_agent_config_update_ext_input():
+ return AgentConfigUpdateExtInput(
+ tools=[{"name":"Image Generation Toolkit"}],
+ schedule=None,
+ goal=["Test Goal"],
+ instruction=["Test Instruction"],
+ constraints=["Test Constraints"],
+ iteration_interval=10,
+ model="Test Model",
+ max_iterations=100,
+ agent_type="Test Agent Type"
+ )
+
+@pytest.fixture
+def mock_update_agent_config():
+ return {
+ "name": "agent_3_UPDATED",
+ "description": "AI assistant to solve complex problems",
+ "goal": ["create a photo of a cat"],
+ "agent_type": "Dynamic Task Workflow",
+ "constraints": [
+ "~4000 word limit for short term memory.",
+ "Your long term memory is short, so immediately save important information to files.",
+ "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
+ "No user assistance",
+ "Exclusively use the commands listed in double quotes e.g. \"command name\""
+ ],
+ "instruction": ["Be accurate"],
+ "tools":[
+ {
+ "name":"Image Generation Toolkit"
+ }
+ ],
+ "iteration_interval": 500,
+ "model": "gpt-4",
+ "max_iterations": 100
+ }
+# Define test cases
+
+def test_update_agent_not_found(mock_update_agent_config,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.put(
+ "/v1/agent/1",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_update_agent_config
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"Agent not found"}'
+
+
+def test_get_run_resources_no_run_ids(mock_run_id_config_empty,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock, \
+ patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.post(
+ "v1/agent/resources/output",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_run_id_config_empty
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"No execution_id found"}'
+
+def test_get_run_resources_invalid_run_ids(mock_run_id_config_invalid,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock, \
+ patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.post(
+ "v1/agent/resources/output",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_run_id_config_invalid
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"One or more run_ids not found"}'
+
+def test_resume_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.post(
+ "/v1/agent/1/resume",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_execution_state_change_input
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"Agent not found"}'
+
+
+def test_pause_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.post(
+ "/v1/agent/1/pause",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_execution_state_change_input
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"Agent not found"}'
+
+def test_create_run_agent_not_found(mock_agent_execution,mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock:
+
+ # Mock the session
+ mock_session = create_autospec(Session)
+ # # Configure session query methods to return None for agent
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ response = client.post(
+ "/v1/agent/1/run",
+ headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers
+ json=mock_agent_execution
+ )
+ assert response.status_code == 404
+ assert response.text == '{"detail":"Agent not found"}'
+
+def test_create_run_project_not_matching_org(mock_agent_execution, mock_api_key_get):
+ with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
+ patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.controllers.api.agent.db') as db_mock:
+
+ # Mock the session and configure query methods to return agent and project
+ mock_session = create_autospec(Session)
+ mock_agent = Agent(id=1, project_id=1, agent_workflow_id=1)
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_agent
+ mock_project = Project(id=1, organisation_id=2) # Different organisation ID
+ db_mock.Project.find_by_id.return_value = mock_project
+ db_mock.session.return_value.__enter__.return_value = mock_session
+
+ response = client.post(
+ "/v1/agent/1/run",
+ headers={"X-API-Key": mock_api_key_get},
+ json=mock_agent_execution
+ )
+
+ assert response.status_code == 404
+ assert response.text == '{"detail":"Agent not found"}'
diff --git a/tests/unit_tests/models/test_agent.py b/tests/unit_tests/models/test_agent.py
index da7e6c4c1..18d614ce0 100644
--- a/tests/unit_tests/models/test_agent.py
+++ b/tests/unit_tests/models/test_agent.py
@@ -22,6 +22,27 @@ def test_get_agent_from_id():
# Assert that the returned agent object matches the mock agent
assert agent == mock_agent
+
+def test_get_active_agent_by_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample agent ID
+ agent_id = 1
+
+ # Create a mock agent object to be returned by the session query
+ mock_agent = Agent(id=agent_id, name="Test Agent", project_id=1, description="Agent for testing",is_deleted=False)
+
+ # Configure the session query to return the mock agent
+ session.query.return_value.filter.return_value.first.return_value = mock_agent
+
+ # Call the method under test
+ agent = Agent.get_active_agent_by_id(session, agent_id)
+
+ # Assert that the returned agent object matches the mock agent
+ assert agent == mock_agent
+ assert agent.is_deleted == False
+
def test_eval_tools_key():
key = "tools"
value = "[1, 2, 3]"
diff --git a/tests/unit_tests/models/test_agent_execution.py b/tests/unit_tests/models/test_agent_execution.py
index a2c91e581..3ecbdc84f 100644
--- a/tests/unit_tests/models/test_agent_execution.py
+++ b/tests/unit_tests/models/test_agent_execution.py
@@ -9,8 +9,6 @@
from superagi.models.agent_execution import AgentExecution
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
-
-
def test_get_agent_execution_from_id():
# Create a mock session
session = create_autospec(Session)
@@ -89,3 +87,28 @@ def test_assign_next_step_id(mock_session, mocker):
# Check that the attributes were updated
assert mock_execution.current_agent_step_id == 2
assert mock_execution.iteration_workflow_step_id == 3
+
+def test_get_execution_by_agent_id_and_status():
+ session = create_autospec(Session)
+
+ # Create a sample agent execution ID
+ agent_execution_id = 1
+
+ # Create a mock agent execution object to be returned by the session query
+ mock_agent_execution = AgentExecution(id=agent_execution_id, name="Test Execution", status="RUNNING")
+
+ # Configure the session query to return the mock agent
+ session.query.return_value.filter.return_value.first.return_value = mock_agent_execution
+
+ # Call the method under test
+ agent_execution = AgentExecution.get_execution_by_agent_id_and_status(session, agent_execution_id,"RUNNING")
+
+ # Assert that the returned agent object matches the mock agent
+ assert agent_execution == mock_agent_execution
+ assert agent_execution.status == "RUNNING"
+
+@pytest.fixture
+def mock_session(mocker):
+ return mocker.MagicMock()
+
+
diff --git a/tests/unit_tests/models/test_agent_schedule.py b/tests/unit_tests/models/test_agent_schedule.py
new file mode 100644
index 000000000..2f5f84355
--- /dev/null
+++ b/tests/unit_tests/models/test_agent_schedule.py
@@ -0,0 +1,23 @@
+from unittest.mock import create_autospec
+
+from sqlalchemy.orm import Session
+from superagi.models.agent_schedule import AgentSchedule
+
+def test_find_by_agent_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample agent ID
+ agent_id = 1
+
+ # Create a mock agent schedule object to be returned by the session query
+ mock_agent_schedule = AgentSchedule(id=1,agent_id=agent_id, start_time="2023-08-10 12:17:00", recurrence_interval="2 Minutes", expiry_runs=2)
+
+ # Configure the session query to return the mock agent
+ session.query.return_value.filter.return_value.first.return_value = mock_agent_schedule
+
+ # Call the method under test
+ agent_schedule = AgentSchedule.find_by_agent_id(session, agent_id)
+
+ # Assert that the returned agent object matches the mock agent
+ assert agent_schedule == mock_agent_schedule
diff --git a/tests/unit_tests/models/test_api_key.py b/tests/unit_tests/models/test_api_key.py
new file mode 100644
index 000000000..fe5752f9d
--- /dev/null
+++ b/tests/unit_tests/models/test_api_key.py
@@ -0,0 +1,93 @@
+from unittest.mock import create_autospec
+
+from sqlalchemy.orm import Session
+from superagi.models.api_key import ApiKey
+
+def test_get_by_org_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample organization ID
+ org_id = 1
+
+ # Create a mock ApiKey object to be returned by the session query
+ mock_api_keys = [
+ ApiKey(id=1, org_id=org_id, key="key1", is_expired=False),
+ ApiKey(id=2, org_id=org_id, key="key2", is_expired=False),
+ ]
+
+ # Configure the session query to return the mock api keys
+ session.query.return_value.filter.return_value.all.return_value = mock_api_keys
+
+ # Call the method under test
+ api_keys = ApiKey.get_by_org_id(session, org_id)
+
+ # Assert that the returned api keys match the mock api keys
+ assert api_keys == mock_api_keys
+
+
+def test_get_by_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample api key ID
+ api_key_id = 1
+
+ # Create a mock ApiKey object to be returned by the session query
+ mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False)
+
+ # Configure the session query to return the mock api key
+ session.query.return_value.filter.return_value.first.return_value = mock_api_key
+
+ # Call the method under test
+ api_key = ApiKey.get_by_id(session, api_key_id)
+
+ # Assert that the returned api key matches the mock api key
+ assert api_key == mock_api_key
+
+def test_delete_by_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample api key ID
+ api_key_id = 1
+
+ # Create a mock ApiKey object to be returned by the session query
+ mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False)
+
+ # Configure the session query to return the mock api key
+ session.query.return_value.filter.return_value.first.return_value = mock_api_key
+
+ # Call the method under test
+ ApiKey.delete_by_id(session, api_key_id)
+
+ # Assert that the api key's is_expired attribute is set to True
+ assert mock_api_key.is_expired == True
+
+ # Assert that the session.commit and session.flush methods were called
+ session.commit.assert_called_once()
+ session.flush.assert_called_once()
+
+def test_edit_by_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample api key ID and new name
+ api_key_id = 1
+ new_name = "New Name"
+
+ # Create a mock ApiKey object to be returned by the session query
+ mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False)
+
+ # Configure the session query to return the mock api key
+ session.query.return_value.filter.return_value.first.return_value = mock_api_key
+
+ # Call the method under test
+ ApiKey.update_api_key(session, api_key_id, new_name)
+
+ # Assert that the api key's name attribute is updated
+ assert mock_api_key.name == new_name
+
+ # Assert that the session.commit and session.flush methods were called
+ session.commit.assert_called_once()
+ session.flush.assert_called_once()
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_project.py b/tests/unit_tests/models/test_project.py
new file mode 100644
index 000000000..9ac868217
--- /dev/null
+++ b/tests/unit_tests/models/test_project.py
@@ -0,0 +1,42 @@
+from unittest.mock import create_autospec
+
+from sqlalchemy.orm import Session
+from superagi.models.project import Project
+
+def test_find_by_org_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample org ID
+ org_id = 123
+
+ # Create a mock project object to be returned by the session query
+ mock_project = Project(id=1, name="Test Project", organisation_id=org_id, description="Project for testing")
+
+ # Configure the session query to return the mock project
+ session.query.return_value.filter.return_value.first.return_value = mock_project
+
+ # Call the method under test
+ project = Project.find_by_org_id(session, org_id)
+
+ # Assert that the returned project object matches the mock project
+ assert project == mock_project
+
+def test_find_by_id():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Create a sample project ID
+ project_id = 123
+
+ # Create a mock project object to be returned by the session query
+ mock_project = Project(id=project_id, name="Test Project", organisation_id=1, description="Project for testing")
+
+ # Configure the session query to return the mock project
+ session.query.return_value.filter.return_value.first.return_value = mock_project
+
+ # Call the method under test
+ project = Project.find_by_id(session, project_id)
+
+ # Assert that the returned project object matches the mock project
+ assert project == mock_project
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_toolkit.py b/tests/unit_tests/models/test_toolkit.py
index 82302d97e..339c970c9 100644
--- a/tests/unit_tests/models/test_toolkit.py
+++ b/tests/unit_tests/models/test_toolkit.py
@@ -1,10 +1,11 @@
-from unittest.mock import MagicMock, patch, call
+from unittest.mock import MagicMock, patch, call,create_autospec,Mock
import pytest
from superagi.models.organisation import Organisation
from superagi.models.toolkit import Toolkit
from superagi.models.tool import Tool
+from sqlalchemy.orm import Session
@pytest.fixture
def mock_session():
@@ -243,3 +244,23 @@ def test_fetch_tool_ids_from_toolkit(mock_tool, mock_session):
# Assert
assert result == [mock_tool.id for _ in toolkit_ids]
+
+def test_get_tool_and_toolkit_arr_with_nonexistent_toolkit():
+ # Create a mock session
+ session = create_autospec(Session)
+
+ # Configure the session query to return None for toolkit
+ session.query.return_value.filter.return_value.first.return_value = None
+
+ # Call the method under test with a non-existent toolkit
+ agent_config_tools_arr = [
+ {"name": "NonExistentToolkit", "tools": ["Tool1", "Tool2"]},
+ ]
+
+ # Use a context manager to capture the raised exception and its message
+ with pytest.raises(Exception) as exc_info:
+ Toolkit.get_tool_and_toolkit_arr(session, agent_config_tools_arr)
+
+ # Assert that the expected error message is contained within the raised exception message
+ expected_error_message = "One or more of the Tool(s)/Toolkit(s) does not exist."
+ assert expected_error_message in str(exc_info.value)