diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b94860d37..8bbd7a241 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,6 +91,7 @@ jobs: ENV: DEV PLAIN_OUTPUT: True REDIS_URL: "localhost:6379" + IS_TESTING: True - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml index 94044916b..66a7086d1 100644 --- a/docker-compose-dev.yaml +++ b/docker-compose-dev.yaml @@ -73,4 +73,4 @@ networks: driver: bridge volumes: superagi_postgres_data: - redis_data: \ No newline at end of file + redis_data: diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 1fb1f5d48..3f7f62b30 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -534,7 +534,7 @@ export default function AgentCreate({ setEditButtonClicked(true); agentData.agent_id = editAgentId; const name = agentData.name - const adjustedDate = new Date((new Date()).getTime() + 6*24*60*60*1000 - 1*60*1000); + const adjustedDate = new Date((new Date()).getTime()); const formattedDate = `${adjustedDate.getDate()} ${['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December'][adjustedDate.getMonth()]} ${adjustedDate.getFullYear()} ${adjustedDate.getHours().toString().padStart(2, '0')}:${adjustedDate.getMinutes().toString().padStart(2, '0')}`; agentData.name = "Run " + formattedDate addAgentRun(agentData) diff --git a/gui/pages/Content/Agents/Agents.module.css b/gui/pages/Content/Agents/Agents.module.css index 62d782ef9..c6553a8f9 100644 --- a/gui/pages/Content/Agents/Agents.module.css +++ b/gui/pages/Content/Agents/Agents.module.css @@ -429,4 +429,35 @@ color: #888888 !important; text-decoration: line-through; pointerEvents: none !important; +} + +.modal_buttons{ + display: flex; + justify-content: flex-end; + margin-top: 20px +} + +.modal_info_class{ + margin-left: -5px; + margin-right: 5px; +} + +.table_contents{ + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + margin-top: 40px; + width: 100% +} + +.create_settings_button{ + display: flex; + justify-content: center; + align-items: center; + margin-top: 10px +} + +.button_margin{ + margin-top: -10px; } \ No newline at end of file diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css index 545146d52..4bd162ade 100644 --- a/gui/pages/Content/Marketplace/Market.module.css +++ b/gui/pages/Content/Marketplace/Market.module.css @@ -515,3 +515,17 @@ overflow-y:scroll; overflow-x:hidden; } + +.settings_tab_button_clicked{ + background: #454254; + padding-right: 15px +} + +.settings_tab_button{ + background: transparent; + padding-right: 15px +} + +.settings_tab_img{ + margin-top: -1px; +} diff --git a/gui/pages/Dashboard/Settings/ApiKeys.js b/gui/pages/Dashboard/Settings/ApiKeys.js new file mode 100644 index 000000000..61654b4ca --- /dev/null +++ b/gui/pages/Dashboard/Settings/ApiKeys.js @@ -0,0 +1,286 @@ +import React, {useState, useEffect, useRef} from 'react'; +import {ToastContainer, toast} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import agentStyles from "@/pages/Content/Agents/Agents.module.css"; +import { + createApiKey, deleteApiKey, + editApiKey, getApiKeys, +} from "@/pages/api/DashboardService"; +import {EventBus} from "@/utils/eventBus"; +import {createInternalId, loadingTextEffect, preventDefault, removeTab, returnToolkitIcon} from "@/utils/utils"; +import Image from "next/image"; +import styles from "@/pages/Content/Marketplace/Market.module.css"; +import styles1 from "@/pages/Content/Knowledge/Knowledge.module.css"; + +export default function ApiKeys() { + const [apiKeys, setApiKeys] = useState([]); + const [keyName, setKeyName] = useState(''); + const [editKey, setEditKey] = useState(''); + const apiKeyRef = useRef(null); + const editKeyRef = useRef(null); + const [editKeyId, setEditKeyId] = useState(-1); + const [deleteKey, setDeleteKey] = useState('') + const [isLoading, setIsLoading] = useState(true) + const [activeDropdown, setActiveDropdown] = useState(null); + const [editModal, setEditModal] = useState(false); + const [deleteKeyId, setDeleteKeyId] = useState(-1); + const [deleteModal, setDeleteModal] = useState(false); + const [createModal, setCreateModal] = useState(false); + const [displayModal, setDisplayModal] = useState(false); + const [apiKeyGenerated, setApiKeyGenerated] = useState(''); + const [loadingText, setLoadingText] = useState("Loading Api Keys"); + + + + useEffect(() => { + loadingTextEffect('Loading Api Keys', setLoadingText, 500); + fetchApiKeys() + }, []); + + + const handleModelApiKey = (event) => { + setKeyName(event.target.value); + }; + + const handleEditApiKey = (event) => { + setEditKey(event.target.value); + }; + + const createApikey = () => { + if(!keyName){ + toast.error("Enter key name", {autoClose: 1800}); + return; + } + createApiKey({name : keyName}) + .then((response) => { + setApiKeyGenerated(response.data.api_key) + toast.success("Api Key Generated", {autoClose: 1800}); + setCreateModal(false); + setDisplayModal(true); + fetchApiKeys(); + }) + .catch((error) => { + console.error('Error creating api key', error); + }); + } + const handleCopyClick = async () => { + if (apiKeyRef.current) { + try { + await navigator.clipboard.writeText(apiKeyRef.current.value); + toast.success("Key Copied", {autoClose: 1800}); + } catch (err) { + toast.error('Failed to Copy', {autoClose: 1800}); + } + } + }; + + const fetchApiKeys = () => { + getApiKeys() + .then((response) => { + const formattedData = response.data.map(item => { + return { + ...item, + created_at: `${new Date(item.created_at).getDate()}-${["JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"][new Date(item.created_at).getMonth()]}-${new Date(item.created_at).getFullYear()}` + }; + }); + setApiKeys(formattedData) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Api Keys', error); + }); + } + + const handleEditClick = () => { + if(editKeyRef.current.value.length <1){ + toast.error("Enter valid key name", {autoClose: 1800}); + return; + } + editApiKey({id: editKeyId,name : editKey}) + .then((response) => { + toast.success("Api Key Edited", {autoClose: 1800}); + fetchApiKeys(); + setEditModal(false); + setEditKey('') + setEditKeyId(-1) + }) + .catch((error) => { + console.error('Error editing api key', error); + }); + } + + const handleDeleteClick = () => { + deleteApiKey(deleteKeyId) + .then((response) => { + toast.success("Api Key Deleted", {autoClose: 1800}); + fetchApiKeys(); + setDeleteModal(false); + setDeleteKeyId(-1) + setDeleteKey('') + }) + .catch((error) => { + toast.error("Error deleting api key", {autoClose: 1800}); + console.error('Error deleting api key', error); + }); + } + + return (<> +
+
+
+ {!isLoading ?
+
+
API Keys
+ {apiKeys && apiKeys.length > 0 && !isLoading && + } +
+
+ + + + {apiKeys.length < 1 &&
+ no-permissions + No API Keys created! +
+ +
+
} + + {apiKeys.length > 0 &&
+ + + + + + + + + +
NameKeyCreated Date
+
+ + + {apiKeys.map((item, index) => ( + + + + + + + ))} + +
{item.name}{item.key.slice(0, 2) + "****" + item.key.slice(-4)}{item.created_at} setActiveDropdown(null)} onClick={() => { + if (activeDropdown === index) { + setActiveDropdown(null); + } else { + setActiveDropdown(index); + } + }}> + run-icon +
setActiveDropdown(null)}> +
    +
  • {setEditKey(item.name); setEditKeyId(item.id); setEditModal(true); setActiveDropdown(null);}}>Edit
  • +
  • {setDeleteKeyId(item.id); setDeleteKey(item.name) ; setDeleteModal(true); setActiveDropdown(null);}}>Delete
  • +
+
+
} +
+
:
+
{loadingText}
+
} +
+
+
+ + {createModal && (
setCreateModal(false)}> +
+
Create new API Key
+
+ + +
+
+ + +
+
+
)} + + {displayModal && apiKeyGenerated && (
setDisplayModal(false)}> +
+
{keyName} is created
+
+
+
+
+ info-icon +
+
+ Your secret API keys are sensitive pieces of information that should be kept confidential. Do not share them with anyone, and do not expose them in any way. If your secret API keys are compromised, someone could use them to access your API and make unauthorized changes to your data. This secret key is only displayed once for security reasons. Please save it in a secure location where you can access it easily. +
+
+
+
+
+
+
+
+ +
+
+
+
+
+ +
+
+
)} + + {editModal && (
{setEditModal(false); setEditKey(''); setEditKeyId(-1)}}> +
+
Edit API Key
+
+ + +
+
+ + +
+
+
)} + + {deleteModal && (
{setDeleteModal(false); setDeleteKeyId(-1); setDeleteKey('')}}> +
+
Delete {deleteKey} Key
+
+ +
+
+ + +
+
+
)} + + ) +} \ No newline at end of file diff --git a/gui/pages/Dashboard/Settings/Settings.js b/gui/pages/Dashboard/Settings/Settings.js index 297e6a3d5..402fc6050 100644 --- a/gui/pages/Dashboard/Settings/Settings.js +++ b/gui/pages/Dashboard/Settings/Settings.js @@ -4,6 +4,7 @@ import styles from "@/pages/Content/Marketplace/Market.module.css"; import Image from "next/image"; import Model from "@/pages/Dashboard/Settings/Model"; import Database from "@/pages/Dashboard/Settings/Database"; +import ApiKeys from "@/pages/Dashboard/Settings/ApiKeys"; export default function Settings({organisationId, sendDatabaseData}) { const [activeTab, setActiveTab] = useState('model'); @@ -44,11 +45,18 @@ export default function Settings({organisationId, sendDatabaseData}) { alt="database-icon"/> Database +
+ +
{activeTab === 'model' && } {activeTab === 'database' && } + {activeTab === 'apikeys' && }
diff --git a/gui/pages/Dashboard/TopBar.js b/gui/pages/Dashboard/TopBar.js index bd954ebb2..3588b8880 100644 --- a/gui/pages/Dashboard/TopBar.js +++ b/gui/pages/Dashboard/TopBar.js @@ -56,10 +56,10 @@ export default function TopBar({selectedProject, userName, env}) { dropdown-icon {dropdown && env === 'PROD' && -
setDropdown(true)} +
setDropdown(true)} onMouseLeave={() => setDropdown(false)}>
    -
  • setDropdown(false)}>{userName}
  • + {userName &&
  • setDropdown(false)}>{userName}
  • }
  • Logout
} diff --git a/gui/pages/_app.css b/gui/pages/_app.css index c78b07085..7be317c86 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -997,13 +997,16 @@ p { .r_0{right: 0} .w_120p{width: 120px} +.w_4{width: 4%} .w_6{width: 6%} .w_10{width: 10%} .w_12{width: 12%} +.w_18{width: 18%} .w_20{width: 20%} .w_22{width: 22%} .w_35{width: 35%} .w_56{width: 56%} +.w_60{width: 60%} .w_100{width: 100%} .w_inherit{width: inherit} .w_fit_content{width:fit-content} @@ -1052,6 +1055,7 @@ p { .gap_16{gap:16px;} .gap_20{gap:20px;} +.border_top_none{border-top: none;} .border_radius_8{border-radius: 8px;} .border_radius_25{border-radius: 25px;} @@ -1072,12 +1076,15 @@ p { .padding_0_8{padding: 0px 8px;} .padding_0_15{padding: 0px 15px;} +.flex_1{flex: 1} .flex_wrap{flex-wrap: wrap;} .mix_blend_mode{mix-blend-mode: exclusion;} .ff_sourceCode{font-family: 'Source Code Pro'} +.rotate_90{transform: rotate(90deg)} + /*------------------------------- My ROWS AND COLUMNS -------------------------------*/ .my_rows { display: flex; @@ -1673,3 +1680,27 @@ tr{ .history_box_selected{ background: #474255; } + +.loading_container{ + display: flex; + justify-content: center; + align-items: center; + height: 50vh +} + +.loading_text{ + font-size: 16px; + font-family: 'Source Code Pro'; +} + +.table_container{ + background: #272335; + border-radius: 8px; + margin-top:15px +} + +.top_bar_profile_dropdown{ + display: flex; + flex-direction: row; + justify-content: center; +} diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index ed9b0a1ef..625c6c2af 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -303,3 +303,20 @@ export const fetchKnowledgeTemplateOverview = (knowledgeName) => { export const installKnowledgeTemplate = (knowledgeName, indexId) => { return api.get(`/knowledges/install/${knowledgeName}/index/${indexId}`); }; + +export const createApiKey = (apiName) => { + return api.post(`/api-keys`, apiName); +}; + +export const getApiKeys = () => { + return api.get(`/api-keys`); +}; + +export const editApiKey = (apiDetails) => { + return api.put(`/api-keys`, apiDetails); +}; + +export const deleteApiKey = (apiId) => { + return api.delete(`/api-keys/${apiId}`); +}; + diff --git a/gui/public/images/copy_icon.svg b/gui/public/images/copy_icon.svg new file mode 100644 index 000000000..46644e0dc --- /dev/null +++ b/gui/public/images/copy_icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/gui/public/images/key_white.svg b/gui/public/images/key_white.svg new file mode 100644 index 000000000..9b35b92e0 --- /dev/null +++ b/gui/public/images/key_white.svg @@ -0,0 +1,3 @@ + + + diff --git a/gui/public/images/weaviate.svg b/gui/public/images/weaviate.svg index 1f1ee9788..9b9f5260c 100644 --- a/gui/public/images/weaviate.svg +++ b/gui/public/images/weaviate.svg @@ -6,4 +6,4 @@ - \ No newline at end of file + diff --git a/main.py b/main.py index f6227a179..4c698eb9a 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,9 @@ from superagi.controllers.vector_dbs import router as vector_dbs_router from superagi.controllers.vector_db_indices import router as vector_db_indices_router from superagi.controllers.marketplace_stats import router as marketplace_stats_router +from superagi.controllers.api_key import router as api_key_router +from superagi.controllers.api.agent import router as api_agent_router +from superagi.controllers.webhook import router as web_hook_router from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits from superagi.lib.logger import logger from superagi.llms.google_palm import GooglePalm @@ -50,7 +53,6 @@ from superagi.models.workflows.agent_workflow import AgentWorkflow from superagi.models.workflows.iteration_workflow import IterationWorkflow from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep - app = FastAPI() database_url = get_config('POSTGRES_URL') @@ -113,7 +115,9 @@ app.include_router(vector_dbs_router, prefix="/vector_dbs") app.include_router(vector_db_indices_router, prefix="/vector_db_indices") app.include_router(marketplace_stats_router, prefix="/marketplace") - +app.include_router(api_key_router, prefix="/api-keys") +app.include_router(api_agent_router,prefix="/v1/agent") +app.include_router(web_hook_router,prefix="/webhook") # in production you can use Settings management # from pydantic to get secret key from .env @@ -370,3 +374,4 @@ def github_client_id(): # # __________________TO RUN____________________________ # # uvicorn main:app --host 0.0.0.0 --port 8001 --reload + diff --git a/migrations/versions/446884dcae58_add_api_key_and_web_hook.py b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py new file mode 100644 index 000000000..c4b353756 --- /dev/null +++ b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py @@ -0,0 +1,65 @@ +"""add api_key and web_hook + +Revision ID: 446884dcae58 +Revises: 71e3980d55f5 +Create Date: 2023-07-29 10:55:21.714245 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '446884dcae58' +down_revision = '2fbd6472112c' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('api_keys', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=True), + sa.Column('name', sa.String(), nullable=True), + sa.Column('key', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('is_expired',sa.Boolean(),nullable=True,default=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('webhooks', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('org_id', sa.Integer(), nullable=True), + sa.Column('url', sa.String(), nullable=True), + sa.Column('headers', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('is_deleted',sa.Boolean(),nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('webhook_events', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('agent_id', sa.Integer(), nullable=True), + sa.Column('run_id', sa.Integer(), nullable=True), + sa.Column('event', sa.String(), nullable=True), + sa.Column('status', sa.String(), nullable=True), + sa.Column('errors', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + #add index ********************* + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('webhooks') + op.drop_table('api_keys') + op.drop_table('webhook_events') + + # ### end Alembic commands ### diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index 743aae05b..1a485562c 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -108,8 +108,7 @@ def create_agent_execution(agent_execution: AgentExecutionIn, agent_execution_configs[agent_config.key] = [] elif agent_config.key == "constraints": if agent_config.value: - constraints = [item.strip('"') for item in agent_config.value.strip('{}').split(',')] - agent_execution_configs[agent_config.key] = constraints + agent_execution_configs[agent_config.key] = agent_config.value else: agent_execution_configs[agent_config.key] = [] else: diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py new file mode 100644 index 000000000..a71e60a8f --- /dev/null +++ b/superagi/controllers/api/agent.py @@ -0,0 +1,326 @@ +from fastapi import APIRouter +from fastapi import HTTPException, Depends ,Security + +from fastapi_sqlalchemy import db +from pydantic import BaseModel + +from superagi.worker import execute_agent +from superagi.helper.auth import validate_api_key,get_organisation_from_api_key +from superagi.models.agent import Agent +from superagi.models.agent_execution_config import AgentExecutionConfiguration +from superagi.models.agent_config import AgentConfiguration +from superagi.models.agent_schedule import AgentSchedule +from superagi.models.project import Project +from superagi.models.workflows.agent_workflow import AgentWorkflow +from superagi.models.agent_execution import AgentExecution +from superagi.models.organisation import Organisation +from superagi.models.resource import Resource +from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput +from superagi.models.workflows.iteration_workflow import IterationWorkflow +from superagi.helper.s3_helper import S3Helper +from datetime import datetime +from typing import Optional,List +from superagi.models.toolkit import Toolkit +from superagi.apm.event_handler import EventHandler +from superagi.config.config import get_config +router = APIRouter() + +class AgentExecutionIn(BaseModel): + name: Optional[str] + goal: Optional[List[str]] + instruction: Optional[List[str]] + + class Config: + orm_mode = True + +class RunFilterConfigIn(BaseModel): + run_ids:Optional[List[int]] + run_status_filter:Optional[str] + + class Config: + orm_mode = True + +class ExecutionStateChangeConfigIn(BaseModel): + run_ids:Optional[List[int]] + + class Config: + orm_mode = True + +class RunIDConfig(BaseModel): + run_ids:List[int] + + class Config: + orm_mode = True + +@router.post("", status_code=200) +def create_agent_with_config(agent_with_config: AgentConfigExtInput, + api_key: str = Security(validate_api_key), organisation:Organisation = Depends(get_organisation_from_api_key)): + project=Project.find_by_org_id(db.session, organisation.id) + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404, detail=str(e)) + + agent_with_config.tools=tools_arr + agent_with_config.project_id=project.id + agent_with_config.exit="No exit criterion" + agent_with_config.permission_type="God Mode" + agent_with_config.LTM_DB=None + db_agent = Agent.create_agent_with_config(db, agent_with_config) + + if agent_with_config.schedule is not None: + agent_schedule = AgentSchedule.save_schedule_from_config(db.session, db_agent, agent_with_config.schedule) + if agent_schedule is None: + raise HTTPException(status_code=500, detail="Failed to schedule agent") + EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, + 'model': agent_with_config.model}, db_agent.id, + organisation.id if organisation else 0) + db.session.commit() + return { + "agent_id": db_agent.id + } + + start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id) + iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, + start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 + # Creating an execution with RUNNING status + execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id, + name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id) + agent_execution_configs = { + "goal": agent_with_config.goal, + "instruction": agent_with_config.instruction + } + db.session.add(execution) + db.session.commit() + db.session.flush() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution, + agent_execution_configs=agent_execution_configs) + + organisation = db_agent.get_agent_organisation(db.session) + EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, + 'model': agent_with_config.model}, db_agent.id, + organisation.id if organisation else 0) + # execute_agent.delay(execution.id, datetime.now()) + db.session.commit() + return { + "agent_id": db_agent.id + } + +@router.post("/{agent_id}/run",status_code=200) +def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent=Agent.get_agent_from_id(db.session,agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id) + if db_schedule is not None: + raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run") + start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id) + db_agent_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED") + + if db_agent_execution is None: + db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), + agent_id=agent_id, name=agent_execution.name, num_of_calls=0, + num_of_tokens=0, + current_step_id=start_step_id) + db.session.add(db_agent_execution) + else: + db_agent_execution.status = "RUNNING" + + db.session.commit() + db.session.flush() + + agent_execution_configs = {} + if agent_execution.goal is not None: + agent_execution_configs = { + "goal": agent_execution.goal, + } + + if agent_execution.instruction is not None: + agent_execution_configs["instructions"] = agent_execution.instruction, + + if agent_execution_configs != {}: + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution, + agent_execution_configs=agent_execution_configs) + EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, + agent_id, organisation.id if organisation else 0) + + if db_agent_execution.status == "RUNNING": + execute_agent.delay(db_agent_execution.id, datetime.now()) + return { + "run_id":db_agent_execution.id + } + +@router.put("/{agent_id}",status_code=200) +def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api_key: str = Security(validate_api_key), + organisation:Organisation = Depends(get_organisation_from_api_key)): + + db_agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not db_agent: + raise HTTPException(status_code=404, detail="agent not found") + + project=Project.find_by_id(db.session, db_agent.project_id) + if project is None: + raise HTTPException(status_code=404, detail="Project not found") + + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + db_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "RUNNING") + if db_execution is not None: + raise HTTPException(status_code=409, detail="Agent is already running,please pause and then update") + + db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id) + if db_schedule is not None: + raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update") + + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404,detail=str(e)) + + if agent_with_config.schedule is not None: + raise HTTPException(status_code=400,detail="Cannot schedule an existing agent") + agent_with_config.tools=tools_arr + agent_with_config.project_id=project.id + agent_with_config.exit="No exit criterion" + agent_with_config.permission_type="God Mode" + agent_with_config.LTM_DB=None + + for key,value in agent_with_config.dict().items(): + if hasattr(db_agent,key) and value is not None: + setattr(db_agent,key,value) + db.session.commit() + db.session.flush() + + start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id) + iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, + start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 + execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id, + name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id) + agent_execution_configs = { + "goal": agent_with_config.goal, + "instruction": agent_with_config.instruction, + "tools":agent_with_config.tools, + "constraints": agent_with_config.constraints, + "iteration_interval": agent_with_config.iteration_interval, + "model": agent_with_config.model, + "max_iterations": agent_with_config.max_iterations, + "agent_workflow": agent_with_config.agent_workflow, + } + agent_configurations = [ + AgentConfiguration(agent_id=db_agent.id, key=key, value=str(value)) + for key, value in agent_execution_configs.items() + ] + db.session.add_all(agent_configurations) + db.session.add(execution) + db.session.commit() + db.session.flush() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution, + agent_execution_configs=agent_execution_configs) + db.session.commit() + + return { + "agent_id":db_agent.id + } + + +@router.post("/{agent_id}/run-status") +def get_agent_runs(agent_id:int,filter_config:RunFilterConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + db_execution_arr=[] + if filter_config.run_status_filter is not None: + filter_config.run_status_filter=filter_config.run_status_filter.upper() + + db_execution_arr=AgentExecution.get_all_executions_by_filter_config(db.session, agent.id, filter_config) + + response_arr=[] + for ind_execution in db_execution_arr: + response_arr.append({"run_id":ind_execution.id, "status":ind_execution.status}) + + return response_arr + + +@router.post("/{agent_id}/pause",status_code=200) +def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + #Checking if the run_ids whose output files are requested belong to the organisation + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "RUNNING") + for ind_execution in db_execution_arr: + ind_execution.status="PAUSED" + db.session.commit() + db.session.flush() + return { + "result":"success" + } + +@router.post("/{agent_id}/resume",status_code=200) +def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "PAUSED") + for ind_execution in db_execution_arr: + ind_execution.status="RUNNING" + + db.session.commit() + db.session.flush() + return { + "result":"success" + } + +@router.post("/resources/output",status_code=201) +def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + if get_config('STORAGE_TYPE') != "S3": + raise HTTPException(status_code=400,detail="This endpoint only works when S3 is configured") + run_ids_arr=run_id_config.run_ids + if len(run_ids_arr)==0: + raise HTTPException(status_code=404, + detail=f"No execution_id found") + #Checking if the run_ids whose output files are requested belong to the organisation + try: + AgentExecution.validate_run_ids(db.session,run_ids_arr,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_resources_arr=Resource.find_by_run_ids(db.session, run_ids_arr) + + try: + response_obj=S3Helper().get_download_url_of_resources(db_resources_arr) + except: + raise HTTPException(status_code=401,detail="Invalid S3 credentials") + return response_obj + diff --git a/superagi/controllers/api_key.py b/superagi/controllers/api_key.py new file mode 100644 index 000000000..57e5c739b --- /dev/null +++ b/superagi/controllers/api_key.py @@ -0,0 +1,55 @@ +import json +import uuid +from fastapi import APIRouter, Body +from fastapi import HTTPException, Depends +from fastapi_jwt_auth import AuthJWT +from fastapi_sqlalchemy import db +from pydantic import BaseModel +from superagi.helper.auth import get_user_organisation +from superagi.helper.auth import check_auth +from superagi.models.api_key import ApiKey +from typing import Optional, Annotated +router = APIRouter() + +class ApiKeyIn(BaseModel): + id:int + name: str + class Config: + orm_mode = True + +class ApiKeyDeleteIn(BaseModel): + id:int + class Config: + orm_mode = True + +@router.post("") +def create_api_key(name: Annotated[str,Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): + api_key=str(uuid.uuid4()) + obj=ApiKey(key=api_key,name=name,org_id=organisation.id) + db.session.add(obj) + db.session.commit() + db.session.flush() + return {"api_key": api_key} + +@router.get("") +def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): + api_keys=ApiKey.get_by_org_id(db.session, organisation.id) + return api_keys + +@router.delete("/{api_key_id}") +def delete_api_key(api_key_id:int, Authorize: AuthJWT = Depends(check_auth)): + api_key=ApiKey.get_by_id(db.session, api_key_id) + if api_key is None: + raise HTTPException(status_code=404, detail="API key not found") + ApiKey.delete_by_id(db.session, api_key_id) + return {"success": True} + +@router.put("") +def edit_api_key(api_key_in:ApiKeyIn,Authorize: AuthJWT = Depends(check_auth)): + api_key=ApiKey.get_by_id(db.session, api_key_in.id) + if api_key is None: + raise HTTPException(status_code=404, detail="API key not found") + ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name) + return {"success": True} + + diff --git a/superagi/controllers/types/agent_with_config.py b/superagi/controllers/types/agent_with_config.py index f3995b5df..5ce81d211 100644 --- a/superagi/controllers/types/agent_with_config.py +++ b/superagi/controllers/types/agent_with_config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel from typing import List, Optional - +from superagi.controllers.types.agent_schedule import AgentScheduleInput class AgentConfigInput(BaseModel): name: str @@ -20,3 +20,45 @@ class AgentConfigInput(BaseModel): max_iterations: int user_timezone: Optional[str] knowledge: Optional[int] + + + +class AgentConfigExtInput(BaseModel): + name: str + description: str + project_id: Optional[int] + goal: List[str] + instruction: List[str] + agent_workflow: str + constraints: List[str] + tools: List[dict] + LTM_DB:Optional[str] + exit: Optional[str] + permission_type: Optional[str] + iteration_interval: int + model: str + schedule: Optional[AgentScheduleInput] + max_iterations: int + user_timezone: Optional[str] + knowledge: Optional[int] + +class AgentConfigUpdateExtInput(BaseModel): + name: Optional[str] + description: Optional[str] + project_id: Optional[int] + goal: Optional[List[str]] + instruction: Optional[List[str]] + agent_workflow: Optional[str] + constraints: Optional[List[str]] + tools: Optional[List[dict]] + LTM_DB:Optional[str] + exit: Optional[str] + permission_type: Optional[str] + iteration_interval: Optional[int] + model: Optional[str] + schedule: Optional[AgentScheduleInput] + max_iterations: Optional[int] + user_timezone: Optional[str] + knowledge: Optional[int] + + diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py new file mode 100644 index 000000000..0a55bd216 --- /dev/null +++ b/superagi/controllers/webhook.py @@ -0,0 +1,60 @@ +from datetime import datetime + +from fastapi import APIRouter +from fastapi import Depends +from fastapi_jwt_auth import AuthJWT +from fastapi_sqlalchemy import db +from pydantic import BaseModel + +# from superagi.types.db import AgentOut, AgentIn +from superagi.helper.auth import check_auth, get_user_organisation +from superagi.models.webhooks import Webhooks + +router = APIRouter() + + +class WebHookIn(BaseModel): + name: str + url: str + headers: dict + + class Config: + orm_mode = True + + +class WebHookOut(BaseModel): + id: int + org_id: int + name: str + url: str + headers: dict + is_deleted: bool + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + + +# CRUD Operations +@router.post("/add", response_model=WebHookOut, status_code=201) +def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth), + organisation=Depends(get_user_organisation)): + """ + Creates a new webhook + + Args: + + Returns: + Agent: An object of Agent representing the created Agent. + + Raises: + HTTPException (Status Code=404): If the associated project is not found. + """ + db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id, + is_deleted=False) + db.session.add(db_webhook) + db.session.commit() + db.session.flush() + + return db_webhook diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py index cc02d5643..f916a3ae5 100644 --- a/superagi/helper/auth.py +++ b/superagi/helper/auth.py @@ -1,10 +1,14 @@ -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Header, Security, status +from fastapi.security import APIKeyHeader from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db - +from fastapi.security.api_key import APIKeyHeader from superagi.config.config import get_config from superagi.models.organisation import Organisation from superagi.models.user import User +from superagi.models.api_key import ApiKey +from typing import Optional +from sqlalchemy import or_ def check_auth(Authorize: AuthJWT = Depends()): @@ -39,6 +43,7 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)): organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first() return organisation + def get_current_user(Authorize: AuthJWT = Depends(check_auth)): env = get_config("ENV", "DEV") @@ -50,4 +55,32 @@ def get_current_user(Authorize: AuthJWT = Depends(check_auth)): # Query the User table to find the user by their email user = db.session.query(User).filter(User.email == email).first() - return user \ No newline at end of file + return user + + +api_key_header = APIKeyHeader(name="X-API-Key") + + +def validate_api_key(api_key: str = Security(api_key_header)) -> str: + query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + if query_result is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API Key", + ) + + return query_result.key + + +def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation: + query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + if query_result is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API Key", + ) + + organisation = db.session.query(Organisation).filter(Organisation.id == query_result.org_id).first() + return organisation \ No newline at end of file diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py index 2b2d0ba1f..2669b77ed 100644 --- a/superagi/helper/s3_helper.py +++ b/superagi/helper/s3_helper.py @@ -5,10 +5,9 @@ from superagi.config.config import get_config from superagi.lib.logger import logger +from urllib.parse import unquote import json - - class S3Helper: def __init__(self, bucket_name = get_config("BUCKET_NAME")): """ @@ -113,4 +112,27 @@ def upload_file_content(self, content, file_path): try: self.s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=content) except: - raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") \ No newline at end of file + raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") + + def get_download_url_of_resources(self,db_resources_arr): + s3 = boto3.client( + 's3', + aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"), + ) + response_obj={} + for db_resource in db_resources_arr: + response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path) + content = response["Body"].read() + bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME") + file_name=db_resource.path.split('/')[-1] + file_name=''.join(char for char in file_name if char != "`") + object_key=f"public_resources/run_id{db_resource.agent_execution_id}/{file_name}" + s3.put_object(Bucket=bucket_name, Key=object_key, Body=content) + file_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}" + resource_execution_id=db_resource.agent_execution_id + if resource_execution_id in response_obj: + response_obj[resource_execution_id].append(file_url) + else: + response_obj[resource_execution_id]=[file_url] + return response_obj \ No newline at end of file diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py new file mode 100644 index 000000000..cf5e988d2 --- /dev/null +++ b/superagi/helper/webhook_manager.py @@ -0,0 +1,37 @@ +from superagi.models.agent import Agent +from superagi.models.agent_execution import AgentExecution +from superagi.models.webhooks import Webhooks +from superagi.models.webhook_events import WebhookEvents +import requests +import json +from superagi.lib.logger import logger +class WebHookManager: + def __init__(self,session): + self.session=session + + def agent_status_change_callback(self, agent_execution_id, curr_status, old_status): + if curr_status=="CREATED" or agent_execution_id is None: + return + agent_id=AgentExecution.get_agent_execution_from_id(self.session,agent_execution_id).agent_id + agent=Agent.get_agent_from_id(self.session,agent_id) + org=agent.get_agent_organisation(self.session) + org_webhooks=self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all() + + for webhook_obj in org_webhooks: + webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"} + error=None + request=None + status='sent' + try: + request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) + except Exception as e: + logger.error(f"Exception occured in webhooks {e}") + error=str(e) + if request is not None and request.status_code not in [200,201] and error is None: + error=request.text + if error is not None: + status='Error' + webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error) + self.session.add(webhook_event) + self.session.commit() + diff --git a/superagi/models/agent.py b/superagi/models/agent.py index 38a8796fc..49bd93f9d 100644 --- a/superagi/models/agent.py +++ b/superagi/models/agent.py @@ -4,8 +4,10 @@ import json from sqlalchemy import Column, Integer, String, Boolean +from sqlalchemy import or_ from superagi.lib.logger import logger +from superagi.models.agent_config import AgentConfiguration from superagi.models.agent_template import AgentTemplate from superagi.models.agent_template_config import AgentTemplateConfig # from superagi.models import AgentConfiguration @@ -13,7 +15,7 @@ from superagi.models.organisation import Organisation from superagi.models.project import Project from superagi.models.workflows.agent_workflow import AgentWorkflow -from superagi.models.agent_config import AgentConfiguration + class Agent(DBBaseModel): """ @@ -35,8 +37,8 @@ class Agent(DBBaseModel): project_id = Column(Integer) description = Column(String) agent_workflow_id = Column(Integer) - is_deleted = Column(Boolean, default = False) - + is_deleted = Column(Boolean, default=False) + def __repr__(self): """ Returns a string representation of the Agent object. @@ -47,8 +49,8 @@ def __repr__(self): """ return f"Agent(id={self.id}, name='{self.name}', project_id={self.project_id}, " \ f"description='{self.description}', agent_workflow_id={self.agent_workflow_id}," \ - f"is_deleted='{self.is_deleted}')" - + f"is_deleted='{self.is_deleted}')" + @classmethod def fetch_configuration(cls, session, agent_id: int): """ @@ -105,7 +107,8 @@ def eval_agent_config(cls, key, value): """ - if key in ["name", "description", "exit", "model", "permission_type", "LTM_DB", "resource_summary", "knowledge"]: + if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", + "resource_summary", "knowledge"]: return value elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]: return int(value) @@ -150,7 +153,6 @@ def create_agent_with_config(cls, db, agent_with_config): # AgentWorkflow.name == "Fixed Task Queue").first() # db_agent.agent_workflow_id = agent_workflow.id - db.session.commit() # Create Agent Configuration @@ -291,3 +293,9 @@ def find_org_by_agent_id(cls, session, agent_id: int): agent = session.query(Agent).filter_by(id=agent_id).first() project = session.query(Project).filter(Project.id == agent.project_id).first() return session.query(Organisation).filter(Organisation.id == project.organisation_id).first() + + @classmethod + def get_active_agent_by_id(cls, session, agent_id: int): + db_agent = session.query(Agent).filter(Agent.id == agent_id, + or_(Agent.is_deleted == False, Agent.is_deleted is None)).first() + return db_agent diff --git a/superagi/models/agent_execution.py b/superagi/models/agent_execution.py index f95afb85b..5cba5f509 100644 --- a/superagi/models/agent_execution.py +++ b/superagi/models/agent_execution.py @@ -164,4 +164,49 @@ def assign_next_step_id(cls, session, agent_execution_id: int, next_step_id: int if next_step.action_type == "ITERATION_WORKFLOW": trigger_step = IterationWorkflow.fetch_trigger_step_id(session, next_step.action_reference_id) agent_execution.iteration_workflow_step_id = trigger_step.id - session.commit() \ No newline at end of file + session.commit() + + @classmethod + def get_execution_by_agent_id_and_status(cls, session, agent_id: int, status_filter: str): + db_agent_execution = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == status_filter).first() + return db_agent_execution + + + @classmethod + def get_all_executions_by_status_and_agent_id(cls, session, agent_id, execution_state_change_input, current_status: str): + db_execution_arr = [] + if execution_state_change_input.run_ids is not None: + db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status,AgentExecution.id.in_(execution_state_change_input.run_ids)).all() + else: + db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status).all() + return db_execution_arr + + @classmethod + def get_all_executions_by_filter_config(cls, session, agent_id: int, filter_config): + db_execution_query = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id) + if filter_config.run_ids is not None: + db_execution_query = db_execution_query.filter(AgentExecution.id.in_(filter_config.run_ids)) + + if filter_config.run_status_filter is not None and filter_config.run_status_filter in ["CREATED", "RUNNING", + "PAUSED", "COMPLETED", + "TERMINATED"]: + db_execution_query = db_execution_query.filter(AgentExecution.status == filter_config.run_status_filter) + + db_execution_arr = db_execution_query.all() + return db_execution_arr + + @classmethod + def validate_run_ids(cls, session, run_ids: list, organisation_id: int): + from superagi.models.agent import Agent + from superagi.models.project import Project + + run_ids=list(set(run_ids)) + agent_ids=session.query(AgentExecution.agent_id).filter(AgentExecution.id.in_(run_ids)).distinct().all() + agent_ids = [id for (id,) in agent_ids] + project_ids=session.query(Agent.project_id).filter(Agent.id.in_(agent_ids)).distinct().all() + project_ids = [id for (id,) in project_ids] + org_ids=session.query(Project.organisation_id).filter(Project.id.in_(project_ids)).distinct().all() + org_ids = [id for (id,) in org_ids] + + if len(org_ids) > 1 or org_ids[0] != organisation_id: + raise Exception(f"one or more run IDs not found") diff --git a/superagi/models/agent_schedule.py b/superagi/models/agent_schedule.py index 6415c375c..32e8dcd84 100644 --- a/superagi/models/agent_schedule.py +++ b/superagi/models/agent_schedule.py @@ -45,4 +45,27 @@ def __repr__(self): f"expiry_date={self.expiry_date}, " \ f"expiry_runs={self.expiry_runs}), " \ f"current_runs={self.expiry_runs}), " \ - f"status={self.status}), " \ No newline at end of file + f"status={self.status}), " + + @classmethod + def save_schedule_from_config(cls, session, db_agent, schedule_config: AgentScheduleInput): + agent_schedule = AgentSchedule( + agent_id=db_agent.id, + start_time=schedule_config.start_time, + next_scheduled_time=schedule_config.start_time, + recurrence_interval=schedule_config.recurrence_interval, + expiry_date=schedule_config.expiry_date, + expiry_runs=schedule_config.expiry_runs, + current_runs=0, + status="SCHEDULED" + ) + + agent_schedule.agent_id = db_agent.id + session.add(agent_schedule) + session.commit() + return agent_schedule + + @classmethod + def find_by_agent_id(cls, session, agent_id: int): + db_schedule=session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id).first() + return db_schedule diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py new file mode 100644 index 000000000..1cc3e310a --- /dev/null +++ b/superagi/models/api_key.py @@ -0,0 +1,46 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution +from sqlalchemy import func, or_ + +class ApiKey(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'api_keys' + + id = Column(Integer, primary_key=True) + org_id = Column(Integer) + name = Column(String) + key = Column(String) + is_expired= Column(Boolean) + + @classmethod + def get_by_org_id(cls, session, org_id: int): + db_api_keys=session.query(ApiKey).filter(ApiKey.org_id==org_id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).all() + return db_api_keys + + @classmethod + def get_by_id(cls, session, id: int): + db_api_key=session.query(ApiKey).filter(ApiKey.id==id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + return db_api_key + + @classmethod + def delete_by_id(cls, session,id: int): + db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first() + db_api_key.is_expired = True + session.commit() + session.flush() + + @classmethod + def update_api_key(cls, session, id: int, name: str): + db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first() + db_api_key.name = name + session.commit() + session.flush() + diff --git a/superagi/models/project.py b/superagi/models/project.py index 1de1eb624..c798e1d1a 100644 --- a/superagi/models/project.py +++ b/superagi/models/project.py @@ -55,3 +55,13 @@ def find_or_create_default_project(cls, session, organisation_id): else: default_project = project return default_project + + @classmethod + def find_by_org_id(cls, session, org_id: int): + project = session.query(Project).filter(Project.organisation_id == org_id).first() + return project + + @classmethod + def find_by_id(cls, session, project_id: int): + project = session.query(Project).filter(Project.id == project_id).first() + return project \ No newline at end of file diff --git a/superagi/models/resource.py b/superagi/models/resource.py index 926123e47..78713b690 100644 --- a/superagi/models/resource.py +++ b/superagi/models/resource.py @@ -58,6 +58,11 @@ def validate_resource_type(storage_type): if storage_type not in valid_types: raise InvalidResourceType("Invalid resource type") - + + @classmethod + def find_by_run_ids(cls, session, run_ids: list): + db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids)).all() + return db_resources_arr + class InvalidResourceType(Exception): """Custom exception for invalid resource type""" diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py index 2c89a38ab..5a9c0a0e9 100644 --- a/superagi/models/toolkit.py +++ b/superagi/models/toolkit.py @@ -138,3 +138,26 @@ def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids): if tool is not None: agent_toolkit_tools.append(tool.id) return agent_toolkit_tools + + @classmethod + def get_tool_and_toolkit_arr(cls, session, agent_config_tools_arr: list): + from superagi.models.tool import Tool + toolkits_arr= set() + tools_arr= set() + for tool_obj in agent_config_tools_arr: + toolkit=session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip()).first() + if toolkit is None: + raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.") + toolkits_arr.add(toolkit.id) + if tool_obj.get("tools"): + for tool_name_str in tool_obj["tools"]: + tool_db_obj=session.query(Tool).filter(Tool.name == tool_name_str.strip()).first() + if tool_db_obj is None: + raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.") + + tools_arr.add(tool_db_obj.id) + else: + tools=Tool.get_toolkit_tools(session, toolkit.id) + for tool_db_obj in tools: + tools_arr.add(tool_db_obj.id) + return list(tools_arr) diff --git a/superagi/models/webhook_events.py b/superagi/models/webhook_events.py new file mode 100644 index 000000000..af90ea492 --- /dev/null +++ b/superagi/models/webhook_events.py @@ -0,0 +1,25 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution + + +class WebhookEvents(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'webhook_events' + + id = Column(Integer, primary_key=True) + agent_id=Column(Integer) + run_id = Column(Integer) + event = Column(String) + status = Column(String) + errors= Column(Text) + + + diff --git a/superagi/models/webhooks.py b/superagi/models/webhooks.py new file mode 100644 index 000000000..14d683472 --- /dev/null +++ b/superagi/models/webhooks.py @@ -0,0 +1,22 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey,JSON +from sqlalchemy.orm import relationship +from sqlalchemy.dialects.postgresql import JSONB +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution + +class Webhooks(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'webhooks' + + id = Column(Integer, primary_key=True) + name=Column(String) + org_id = Column(Integer) + url = Column(String) + headers=Column(JSON) + is_deleted=Column(Boolean) diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 5d02c7f74..c67f8888b 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -94,11 +94,10 @@ def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embeddi return qdrant.Qdrant(client, embedding_model, index_name) except: raise ValueError("Qdrant API key not found") - + if vector_store == VectorStoreType.WEAVIATE: try: client = weaviate.create_weaviate_client(creds["url"], creds["api_key"]) return weaviate.Weaviate(client, embedding_model, index_name) except: raise ValueError("Weaviate API key not found") - \ No newline at end of file diff --git a/superagi/worker.py b/superagi/worker.py index 9e41c2fd8..b6e5ea231 100644 --- a/superagi/worker.py +++ b/superagi/worker.py @@ -15,6 +15,10 @@ from superagi.models.db import connect_db from superagi.types.model_source_types import ModelSourceType +from sqlalchemy import event +from superagi.models.agent_execution import AgentExecution +from superagi.helper.webhook_manager import WebHookManager + redis_url = get_config('REDIS_URL', 'super__redis:6379') app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"]) @@ -32,9 +36,16 @@ } app.conf.beat_schedule = beat_schedule +@event.listens_for(AgentExecution.status, "set") +def agent_status_change(target, val,old_val,initiator): + if get_config("IN_TESTING",False): + webhook_callback.delay(target.id,val,old_val) + + @app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5) def initialize_schedule_agent_task(): """Executing agent scheduling in the background.""" + schedule_helper = AgentScheduleHelper() schedule_helper.update_next_scheduled_time() schedule_helper.run_scheduled_agents() @@ -49,7 +60,7 @@ def execute_agent(agent_execution_id: int, time): AgentExecutor().execute_next_step(agent_execution_id=agent_execution_id) -@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5, serializer='pickle') +@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle') def summarize_resource(agent_id: int, resource_id: int): """Summarize a resource in background.""" from superagi.resource_manager.resource_summary import ResourceSummarizer @@ -77,3 +88,11 @@ def summarize_resource(agent_id: int, resource_id: int): resource_summarizer.add_to_vector_store_and_create_summary(resource_id=resource_id, documents=documents) session.close() + +@app.task(name="webhook_callback", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle') +def webhook_callback(agent_execution_id,val,old_val): + engine = connect_db() + Session = sessionmaker(bind=engine) + with Session() as session: + WebHookManager(session).agent_status_change_callback(agent_execution_id, val, old_val) + diff --git a/tests/unit_tests/controllers/api/__init__.py b/tests/unit_tests/controllers/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/controllers/api/test_agent.py b/tests/unit_tests/controllers/api/test_agent.py new file mode 100644 index 000000000..99fda13b2 --- /dev/null +++ b/tests/unit_tests/controllers/api/test_agent.py @@ -0,0 +1,220 @@ +import pytest +from fastapi.testclient import TestClient +from fastapi import HTTPException + +import superagi.config.config +from unittest.mock import MagicMock, patch,Mock +from main import app +from unittest.mock import patch,create_autospec +from sqlalchemy.orm import Session +from superagi.controllers.api.agent import ExecutionStateChangeConfigIn,AgentConfigUpdateExtInput +from superagi.models.agent import Agent +from superagi.models.project import Project + +client = TestClient(app) + +@pytest.fixture +def mock_api_key_get(): + mock_api_key = "your_mock_api_key" + return mock_api_key +@pytest.fixture +def mock_execution_state_change_input(): + return { + + } +@pytest.fixture +def mock_run_id_config(): + return { + "run_ids":[1,2] + } + +@pytest.fixture +def mock_agent_execution(): + return { + + } +@pytest.fixture +def mock_run_id_config_empty(): + return { + "run_ids":[] + } + +@pytest.fixture +def mock_run_id_config_invalid(): + return { + "run_ids":[12310] + } +@pytest.fixture +def mock_agent_config_update_ext_input(): + return AgentConfigUpdateExtInput( + tools=[{"name":"Image Generation Toolkit"}], + schedule=None, + goal=["Test Goal"], + instruction=["Test Instruction"], + constraints=["Test Constraints"], + iteration_interval=10, + model="Test Model", + max_iterations=100, + agent_type="Test Agent Type" + ) + +@pytest.fixture +def mock_update_agent_config(): + return { + "name": "agent_3_UPDATED", + "description": "AI assistant to solve complex problems", + "goal": ["create a photo of a cat"], + "agent_type": "Dynamic Task Workflow", + "constraints": [ + "~4000 word limit for short term memory.", + "Your long term memory is short, so immediately save important information to files.", + "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.", + "No user assistance", + "Exclusively use the commands listed in double quotes e.g. \"command name\"" + ], + "instruction": ["Be accurate"], + "tools":[ + { + "name":"Image Generation Toolkit" + } + ], + "iteration_interval": 500, + "model": "gpt-4", + "max_iterations": 100 + } +# Define test cases + +def test_update_agent_not_found(mock_update_agent_config,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.put( + "/v1/agent/1", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_update_agent_config + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + + +def test_get_run_resources_no_run_ids(mock_run_id_config_empty,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock, \ + patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "v1/agent/resources/output", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_run_id_config_empty + ) + assert response.status_code == 404 + assert response.text == '{"detail":"No execution_id found"}' + +def test_get_run_resources_invalid_run_ids(mock_run_id_config_invalid,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock, \ + patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "v1/agent/resources/output", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_run_id_config_invalid + ) + assert response.status_code == 404 + assert response.text == '{"detail":"One or more run_ids not found"}' + +def test_resume_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/resume", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_execution_state_change_input + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + + +def test_pause_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/pause", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_execution_state_change_input + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + +def test_create_run_agent_not_found(mock_agent_execution,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/run", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_agent_execution + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + +def test_create_run_project_not_matching_org(mock_agent_execution, mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session and configure query methods to return agent and project + mock_session = create_autospec(Session) + mock_agent = Agent(id=1, project_id=1, agent_workflow_id=1) + mock_session.query.return_value.filter.return_value.first.return_value = mock_agent + mock_project = Project(id=1, organisation_id=2) # Different organisation ID + db_mock.Project.find_by_id.return_value = mock_project + db_mock.session.return_value.__enter__.return_value = mock_session + + response = client.post( + "/v1/agent/1/run", + headers={"X-API-Key": mock_api_key_get}, + json=mock_agent_execution + ) + + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' diff --git a/tests/unit_tests/models/test_agent.py b/tests/unit_tests/models/test_agent.py index da7e6c4c1..18d614ce0 100644 --- a/tests/unit_tests/models/test_agent.py +++ b/tests/unit_tests/models/test_agent.py @@ -22,6 +22,27 @@ def test_get_agent_from_id(): # Assert that the returned agent object matches the mock agent assert agent == mock_agent + +def test_get_active_agent_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample agent ID + agent_id = 1 + + # Create a mock agent object to be returned by the session query + mock_agent = Agent(id=agent_id, name="Test Agent", project_id=1, description="Agent for testing",is_deleted=False) + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent + + # Call the method under test + agent = Agent.get_active_agent_by_id(session, agent_id) + + # Assert that the returned agent object matches the mock agent + assert agent == mock_agent + assert agent.is_deleted == False + def test_eval_tools_key(): key = "tools" value = "[1, 2, 3]" diff --git a/tests/unit_tests/models/test_agent_execution.py b/tests/unit_tests/models/test_agent_execution.py index a2c91e581..3ecbdc84f 100644 --- a/tests/unit_tests/models/test_agent_execution.py +++ b/tests/unit_tests/models/test_agent_execution.py @@ -9,8 +9,6 @@ from superagi.models.agent_execution import AgentExecution from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep from superagi.models.workflows.iteration_workflow import IterationWorkflow - - def test_get_agent_execution_from_id(): # Create a mock session session = create_autospec(Session) @@ -89,3 +87,28 @@ def test_assign_next_step_id(mock_session, mocker): # Check that the attributes were updated assert mock_execution.current_agent_step_id == 2 assert mock_execution.iteration_workflow_step_id == 3 + +def test_get_execution_by_agent_id_and_status(): + session = create_autospec(Session) + + # Create a sample agent execution ID + agent_execution_id = 1 + + # Create a mock agent execution object to be returned by the session query + mock_agent_execution = AgentExecution(id=agent_execution_id, name="Test Execution", status="RUNNING") + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent_execution + + # Call the method under test + agent_execution = AgentExecution.get_execution_by_agent_id_and_status(session, agent_execution_id,"RUNNING") + + # Assert that the returned agent object matches the mock agent + assert agent_execution == mock_agent_execution + assert agent_execution.status == "RUNNING" + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + diff --git a/tests/unit_tests/models/test_agent_schedule.py b/tests/unit_tests/models/test_agent_schedule.py new file mode 100644 index 000000000..2f5f84355 --- /dev/null +++ b/tests/unit_tests/models/test_agent_schedule.py @@ -0,0 +1,23 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.agent_schedule import AgentSchedule + +def test_find_by_agent_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample agent ID + agent_id = 1 + + # Create a mock agent schedule object to be returned by the session query + mock_agent_schedule = AgentSchedule(id=1,agent_id=agent_id, start_time="2023-08-10 12:17:00", recurrence_interval="2 Minutes", expiry_runs=2) + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent_schedule + + # Call the method under test + agent_schedule = AgentSchedule.find_by_agent_id(session, agent_id) + + # Assert that the returned agent object matches the mock agent + assert agent_schedule == mock_agent_schedule diff --git a/tests/unit_tests/models/test_api_key.py b/tests/unit_tests/models/test_api_key.py new file mode 100644 index 000000000..fe5752f9d --- /dev/null +++ b/tests/unit_tests/models/test_api_key.py @@ -0,0 +1,93 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.api_key import ApiKey + +def test_get_by_org_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample organization ID + org_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_keys = [ + ApiKey(id=1, org_id=org_id, key="key1", is_expired=False), + ApiKey(id=2, org_id=org_id, key="key2", is_expired=False), + ] + + # Configure the session query to return the mock api keys + session.query.return_value.filter.return_value.all.return_value = mock_api_keys + + # Call the method under test + api_keys = ApiKey.get_by_org_id(session, org_id) + + # Assert that the returned api keys match the mock api keys + assert api_keys == mock_api_keys + + +def test_get_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID + api_key_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + api_key = ApiKey.get_by_id(session, api_key_id) + + # Assert that the returned api key matches the mock api key + assert api_key == mock_api_key + +def test_delete_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID + api_key_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + ApiKey.delete_by_id(session, api_key_id) + + # Assert that the api key's is_expired attribute is set to True + assert mock_api_key.is_expired == True + + # Assert that the session.commit and session.flush methods were called + session.commit.assert_called_once() + session.flush.assert_called_once() + +def test_edit_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID and new name + api_key_id = 1 + new_name = "New Name" + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + ApiKey.update_api_key(session, api_key_id, new_name) + + # Assert that the api key's name attribute is updated + assert mock_api_key.name == new_name + + # Assert that the session.commit and session.flush methods were called + session.commit.assert_called_once() + session.flush.assert_called_once() \ No newline at end of file diff --git a/tests/unit_tests/models/test_project.py b/tests/unit_tests/models/test_project.py new file mode 100644 index 000000000..9ac868217 --- /dev/null +++ b/tests/unit_tests/models/test_project.py @@ -0,0 +1,42 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.project import Project + +def test_find_by_org_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample org ID + org_id = 123 + + # Create a mock project object to be returned by the session query + mock_project = Project(id=1, name="Test Project", organisation_id=org_id, description="Project for testing") + + # Configure the session query to return the mock project + session.query.return_value.filter.return_value.first.return_value = mock_project + + # Call the method under test + project = Project.find_by_org_id(session, org_id) + + # Assert that the returned project object matches the mock project + assert project == mock_project + +def test_find_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample project ID + project_id = 123 + + # Create a mock project object to be returned by the session query + mock_project = Project(id=project_id, name="Test Project", organisation_id=1, description="Project for testing") + + # Configure the session query to return the mock project + session.query.return_value.filter.return_value.first.return_value = mock_project + + # Call the method under test + project = Project.find_by_id(session, project_id) + + # Assert that the returned project object matches the mock project + assert project == mock_project \ No newline at end of file diff --git a/tests/unit_tests/models/test_toolkit.py b/tests/unit_tests/models/test_toolkit.py index 82302d97e..339c970c9 100644 --- a/tests/unit_tests/models/test_toolkit.py +++ b/tests/unit_tests/models/test_toolkit.py @@ -1,10 +1,11 @@ -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch, call,create_autospec,Mock import pytest from superagi.models.organisation import Organisation from superagi.models.toolkit import Toolkit from superagi.models.tool import Tool +from sqlalchemy.orm import Session @pytest.fixture def mock_session(): @@ -243,3 +244,23 @@ def test_fetch_tool_ids_from_toolkit(mock_tool, mock_session): # Assert assert result == [mock_tool.id for _ in toolkit_ids] + +def test_get_tool_and_toolkit_arr_with_nonexistent_toolkit(): + # Create a mock session + session = create_autospec(Session) + + # Configure the session query to return None for toolkit + session.query.return_value.filter.return_value.first.return_value = None + + # Call the method under test with a non-existent toolkit + agent_config_tools_arr = [ + {"name": "NonExistentToolkit", "tools": ["Tool1", "Tool2"]}, + ] + + # Use a context manager to capture the raised exception and its message + with pytest.raises(Exception) as exc_info: + Toolkit.get_tool_and_toolkit_arr(session, agent_config_tools_arr) + + # Assert that the expected error message is contained within the raised exception message + expected_error_message = "One or more of the Tool(s)/Toolkit(s) does not exist." + assert expected_error_message in str(exc_info.value)