diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js
index 51cf7897a..a724c5ecf 100644
--- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js
+++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js
@@ -1,7 +1,7 @@
import React, {useEffect, useState} from 'react';
import Image from 'next/image';
import {ToastContainer, toast} from 'react-toastify';
-import {updateToolConfig, getToolConfig, authenticateGoogleCred} from "@/pages/api/DashboardService";
+import {updateToolConfig, getToolConfig, authenticateGoogleCred, authenticateTwitterCred} from "@/pages/api/DashboardService";
import styles from './Tool.module.css';
import {setLocalStorageValue, setLocalStorageArray} from "@/utils/utils";
@@ -25,6 +25,13 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
window.location.href = `https://accounts.google.com/o/oauth2/v2/auth?client_id=${client_id}&redirect_uri=${redirect_uri}&access_type=offline&response_type=code&scope=${scope}`;
}
+ function getTwitterToken(oauth_data){
+ const oauth_token = oauth_data.oauth_token
+ const oauth_token_secret = oauth_data.oauth_token_secret
+ const authUrl = `https://api.twitter.com/oauth/authenticate?oauth_token=${oauth_token}`
+ window.location.href = authUrl
+ }
+
useEffect(() => {
if(toolkitDetails !== null) {
if (toolkitDetails.tools) {
@@ -37,7 +44,7 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
const apiConfigs = response.data || [];
setApiConfigs(localStoredConfigs ? JSON.parse(localStoredConfigs) : apiConfigs);
})
- .catch((error) => {
+ .catch((errPor) => {
console.log('Error fetching API data:', error);
})
.finally(() => {
@@ -72,6 +79,17 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
});
};
+ const handleTwitterAuthClick = async () => {
+ authenticateTwitterCred(toolkitDetails.id)
+ .then((response) => {
+ getTwitterToken(response.data);
+ localStorage.setItem("twitter_toolkit_id", toolkitDetails.id)
+ })
+ .catch((error) => {
+ console.error('Error fetching data: ', error);
+ });
+ };
+
useEffect(() => {
const active_tab = localStorage.getItem('toolkit_tab_' + String(internalId));
if(active_tab) {
@@ -124,6 +142,7 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
{toolkitDetails.name === 'Google Calendar Toolkit' && }
+ {toolkitDetails.name === 'Twitter Toolkit' && }
diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js
index 3698c6a06..e1e841213 100644
--- a/gui/pages/Dashboard/Content.js
+++ b/gui/pages/Dashboard/Content.js
@@ -7,9 +7,13 @@ import Settings from "./Settings/Settings";
import styles from './Dashboard.module.css';
import Image from "next/image";
import { EventBus } from "@/utils/eventBus";
-import {getAgents, getToolKit, getLastActiveAgent} from "@/pages/api/DashboardService";
+import {getAgents, getToolKit, getLastActiveAgent, sendTwitterCreds} from "@/pages/api/DashboardService";
import Market from "../Content/Marketplace/Market";
import AgentTemplatesList from '../Content/Agents/AgentTemplatesList';
+import { useRouter } from 'next/router';
+import querystring from 'querystring';
+import { userInfo } from 'os';
+import { parse } from 'path';
import AddTool from "@/pages/Content/Toolkits/AddTool";
import {createInternalId, removeInternalId} from "@/utils/utils";
@@ -20,6 +24,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat
const [toolkits, setToolkits] = useState(null);
const tabContainerRef = useRef(null);
const [toolkitDetails, setToolkitDetails] = useState({})
+ const router = useRouter();
function fetchAgents() {
getAgents(selectedProjectId)
@@ -140,6 +145,21 @@ export default function Content({env, selectedView, selectedProjectId, organisat
}
}
}
+ const queryParams = router.asPath.split('?')[1];
+ const parsedParams = querystring.parse(queryParams);
+ parsedParams["toolkit_id"] = toolkitDetails.toolkit_id;
+ if (window.location.href.indexOf("twitter_creds") > -1){
+ const toolkit_id = localStorage.getItem("twitter_toolkit_id") || null;
+ parsedParams["toolkit_id"] = toolkit_id;
+ const params = JSON.stringify(parsedParams)
+ sendTwitterCreds(params)
+ .then((response) => {
+ console.log("Authentication completed successfully");
+ })
+ .catch((error) => {
+ console.error("Error fetching data: ",error);
+ })
+ };
}, [selectedTab]);
useEffect(() => {
@@ -171,7 +191,6 @@ export default function Content({env, selectedView, selectedProjectId, organisat
EventBus.off('openNewTab', openNewTab);
EventBus.off('reFetchAgents', fetchAgents);
EventBus.off('removeTab', removeTab);
- EventBus.off('openToolkitTab', openToolkitTab);
};
});
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index aa86d1f2e..dc3ed4530 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -132,6 +132,14 @@ export const authenticateGoogleCred = (toolKitId) => {
return api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`);
}
+export const authenticateTwitterCred = (toolKitId) => {
+ return api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`);
+}
+
+export const sendTwitterCreds = (twitter_creds) => {
+ return api.post(`/twitter/send_twitter_creds/${twitter_creds}`);
+}
+
export const fetchToolTemplateList = () => {
return api.get(`/toolkits/get/list?page=0`);
}
diff --git a/main.py b/main.py
index 222e9dd4c..7d723fb2b 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,11 @@
from sqlalchemy.orm import sessionmaker
import superagi
+import urllib.parse
+import json
+import http.client as http_client
+from superagi.helper.twitter_tokens import TwitterTokens
+from datetime import datetime, timedelta
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.config.config import get_config
from superagi.controllers.agent import router as agent_router
@@ -28,6 +33,7 @@
from superagi.controllers.config import router as config_router
from superagi.controllers.organisation import router as organisation_router
from superagi.controllers.project import router as project_router
+from superagi.controllers.twitter_oauth import router as twitter_oauth_router
from superagi.controllers.resources import router as resources_router
from superagi.controllers.tool import router as tool_router
from superagi.controllers.tool_config import router as tool_config_router
@@ -36,11 +42,13 @@
from superagi.helper.tool_helper import register_toolkits
from superagi.lib.logger import logger
from superagi.llms.openai import OpenAi
+from superagi.helper.auth import get_current_user
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.agent_workflow_step import AgentWorkflowStep
from superagi.models.organisation import Organisation
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
from superagi.models.types.login_request import LoginRequest
from superagi.models.user import User
@@ -97,6 +105,7 @@
app.include_router(config_router, prefix="/configs")
app.include_router(agent_template_router, prefix="/agent_templates")
app.include_router(agent_workflow_router, prefix="/agent_workflows")
+app.include_router(twitter_oauth_router, prefix="/twitter")
# in production you can use Settings management
@@ -320,7 +329,6 @@ async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depe
frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(frontend_url)
-
@app.get('/github-login')
def github_login():
"""GitHub login"""
@@ -411,7 +419,6 @@ def get_google_calendar_tool_configs(toolkit_id: int):
"client_id": google_calendar_config.value
}
-
@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
diff --git a/migrations/versions/c5c19944c90c_create_oauth_tokens.py b/migrations/versions/c5c19944c90c_create_oauth_tokens.py
new file mode 100644
index 000000000..8986afcdc
--- /dev/null
+++ b/migrations/versions/c5c19944c90c_create_oauth_tokens.py
@@ -0,0 +1,48 @@
+"""Create Oauth Tokens
+
+Revision ID: c5c19944c90c
+Revises: 7a3e336c0fba
+Create Date: 2023-06-30 07:26:29.180784
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'c5c19944c90c'
+down_revision = '7a3e336c0fba'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('oauth_tokens',
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('user_id', sa.Integer(), nullable=True),
+ sa.Column('organisation_id', sa.Integer(), nullable=True),
+ sa.Column('toolkit_id', sa.Integer(), nullable=True),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.drop_index('ix_agent_execution_permissions_agent_execution_id', table_name='agent_execution_permissions')
+ op.drop_index('ix_atc_agnt_template_id_key', table_name='agent_template_configs')
+ op.drop_index('ix_agt_agnt_name', table_name='agent_templates')
+ op.drop_index('ix_agt_agnt_organisation_id', table_name='agent_templates')
+ op.drop_index('ix_agt_agnt_workflow_id', table_name='agent_templates')
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_index('ix_agt_agnt_workflow_id', 'agent_templates', ['agent_workflow_id'], unique=False)
+ op.create_index('ix_agt_agnt_organisation_id', 'agent_templates', ['organisation_id'], unique=False)
+ op.create_index('ix_agt_agnt_name', 'agent_templates', ['name'], unique=False)
+ op.create_index('ix_atc_agnt_template_id_key', 'agent_template_configs', ['agent_template_id', 'key'], unique=False)
+ op.create_index('ix_agent_execution_permissions_agent_execution_id', 'agent_execution_permissions', ['agent_execution_id'], unique=False)
+ op.drop_table('oauth_tokens')
+ # ### end Alembic commands ###
diff --git a/superagi/controllers/twitter_oauth.py b/superagi/controllers/twitter_oauth.py
new file mode 100644
index 000000000..b79b7be66
--- /dev/null
+++ b/superagi/controllers/twitter_oauth.py
@@ -0,0 +1,70 @@
+from fastapi import Depends, Query
+from fastapi import APIRouter
+from fastapi.responses import RedirectResponse
+from fastapi_jwt_auth import AuthJWT
+from fastapi_sqlalchemy import db
+from sqlalchemy.orm import sessionmaker
+
+import superagi
+import json
+from superagi.models.db import connect_db
+import http.client as http_client
+from superagi.helper.twitter_tokens import TwitterTokens
+from superagi.helper.auth import get_current_user
+from superagi.models.tool_config import ToolConfig
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+
+router = APIRouter()
+
+@router.get('/oauth-tokens')
+async def twitter_oauth(oauth_token: str = Query(...),oauth_verifier: str = Query(...), Authorize: AuthJWT = Depends()):
+ print("///////////////////////////")
+ print(oauth_token)
+ token_uri = f'https://api.twitter.com/oauth/access_token?oauth_verifier={oauth_verifier}&oauth_token={oauth_token}'
+ conn = http_client.HTTPSConnection("api.twitter.com")
+ conn.request("POST", token_uri, "")
+ res = conn.getresponse()
+ response_data = res.read().decode('utf-8')
+ frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
+ redirect_url_success = f"{frontend_url}/twitter_creds/?{response_data}"
+ return RedirectResponse(url=redirect_url_success)
+
+@router.post("/send_twitter_creds/{twitter_creds}")
+def send_twitter_tool_configs(twitter_creds: str, Authorize: AuthJWT = Depends()):
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ current_user = get_current_user()
+ user_id = current_user.id
+ credentials = json.loads(twitter_creds)
+ credentials["user_id"] = user_id
+ toolkit = db.session.query(Toolkit).filter(Toolkit.id == credentials["toolkit_id"]).first()
+ api_key = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_KEY", ToolConfig.toolkit_id == credentials["toolkit_id"]).first()
+ api_key_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_SECRET", ToolConfig.toolkit_id == credentials["toolkit_id"]).first()
+ final_creds = {
+ "api_key": api_key.value,
+ "api_key_secret": api_key_secret.value,
+ "oauth_token": credentials["oauth_token"],
+ "oauth_token_secret": credentials["oauth_token_secret"]
+ }
+ tokens = OauthTokens().add_or_update(session, credentials["toolkit_id"], user_id, toolkit.organisation_id, "TWITTER_OAUTH_TOKENS", str(final_creds))
+ if tokens:
+ success = True
+ else:
+ success = False
+ return success
+
+@router.get("/get_twitter_creds/toolkit_id/{toolkit_id}")
+def get_twitter_tool_configs(toolkit_id: int):
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ twitter_config_key = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_KEY").first()
+ twitter_config_secret = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_SECRET").first()
+ api_data = {
+ "api_key": twitter_config_key.value,
+ "api_secret": twitter_config_secret.value
+ }
+ response = TwitterTokens(session).get_request_token(api_data)
+ return response
\ No newline at end of file
diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py
index 5a80677e6..a185ece83 100644
--- a/superagi/helper/auth.py
+++ b/superagi/helper/auth.py
@@ -33,6 +33,13 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
Returns:
Organisation: Instance of Organisation class to which the authenticated user belongs.
"""
+ user = get_current_user()
+ if user is None:
+ raise HTTPException(status_code=401, detail="Unauthenticated")
+ organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first()
+ return organisation
+
+def get_current_user(Authorize: AuthJWT = Depends(check_auth)):
env = get_config("ENV", "DEV")
if env == "DEV":
@@ -43,7 +50,4 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
# Query the User table to find the user by their email
user = db.session.query(User).filter(User.email == email).first()
- if user is None:
- raise HTTPException(status_code=401, detail="Unauthenticated")
- organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first()
- return organisation
\ No newline at end of file
+ return user
\ No newline at end of file
diff --git a/superagi/helper/twitter_helper.py b/superagi/helper/twitter_helper.py
new file mode 100644
index 000000000..e47c3d60d
--- /dev/null
+++ b/superagi/helper/twitter_helper.py
@@ -0,0 +1,42 @@
+import os
+import json
+import base64
+import requests
+from requests_oauthlib import OAuth1
+from requests_oauthlib import OAuth1Session
+from superagi.helper.resource_helper import ResourceHelper
+
+class TwitterHelper:
+
+ def get_media_ids(self, media_files, creds, agent_id):
+ media_ids = []
+ oauth = OAuth1(creds.api_key,
+ client_secret=creds.api_key_secret,
+ resource_owner_key=creds.oauth_token,
+ resource_owner_secret=creds.oauth_token_secret)
+ for file in media_files:
+ file_path = self.get_file_path(file, agent_id)
+ image_data = open(file_path, 'rb').read()
+ b64_image = base64.b64encode(image_data)
+ upload_endpoint = 'https://upload.twitter.com/1.1/media/upload.json'
+ headers = {'Authorization': 'application/octet-stream'}
+ response = requests.post(upload_endpoint, headers=headers,
+ data={'media_data': b64_image},
+ auth=oauth)
+ ids = json.loads(response.text)['media_id']
+ media_ids.append(str(ids))
+ return media_ids
+
+ def get_file_path(self, file_name, agent_id):
+ final_path = ResourceHelper().get_agent_resource_path(file_name, agent_id)
+ return final_path
+
+ def send_tweets(self, params, creds):
+ tweet_endpoint = "https://api.twitter.com/2/tweets"
+ oauth = OAuth1Session(creds.api_key,
+ client_secret=creds.api_key_secret,
+ resource_owner_key=creds.oauth_token,
+ resource_owner_secret=creds.oauth_token_secret)
+
+ response = oauth.post(tweet_endpoint,json=params)
+ return response
diff --git a/superagi/helper/twitter_tokens.py b/superagi/helper/twitter_tokens.py
new file mode 100644
index 000000000..10b36d59a
--- /dev/null
+++ b/superagi/helper/twitter_tokens.py
@@ -0,0 +1,77 @@
+import hmac
+import time
+import random
+import base64
+import hashlib
+import urllib.parse
+import ast
+import http.client as http_client
+from sqlalchemy.orm import Session
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+
+class Creds:
+
+ def __init__(self,api_key, api_key_secret, oauth_token, oauth_token_secret):
+ self.api_key = api_key
+ self.api_key_secret = api_key_secret
+ self.oauth_token = oauth_token
+ self.oauth_token_secret = oauth_token_secret
+
+class TwitterTokens:
+
+ def __init__(self, session: Session):
+ self.session = session
+
+ def get_request_token(self,api_data):
+ api_key = api_data["api_key"]
+ api_secret_key = api_data["api_secret"]
+ http_method = 'POST'
+ base_url = 'https://api.twitter.com/oauth/request_token'
+
+ params = {
+ 'oauth_callback': 'http://localhost:3000/api/twitter/oauth-tokens',
+ 'oauth_consumer_key': api_key,
+ 'oauth_nonce': self.gen_nonce(),
+ 'oauth_signature_method': 'HMAC-SHA1',
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_version': '1.0'
+ }
+
+ params_sorted = sorted(params.items())
+ params_qs = '&'.join([f'{k}={self.percent_encode(str(v))}' for k, v in params_sorted])
+
+ base_string = f'{http_method}&{self.percent_encode(base_url)}&{self.percent_encode(params_qs)}'
+
+ signing_key = f'{self.percent_encode(api_secret_key)}&'
+ signature = hmac.new(signing_key.encode(), base_string.encode(), hashlib.sha1)
+ params['oauth_signature'] = base64.b64encode(signature.digest()).decode()
+
+ auth_header = 'OAuth ' + ', '.join([f'{k}="{self.percent_encode(str(v))}"' for k, v in params.items()])
+
+ headers = {
+ 'Content-Type': 'application/x-www-form-urlencoded',
+ 'Authorization': auth_header
+ }
+ conn = http_client.HTTPSConnection("api.twitter.com")
+ conn.request("POST", "/oauth/request_token", "", headers)
+ res = conn.getresponse()
+ response_data = res.read().decode('utf-8')
+ conn.close()
+ request_token_resp = dict(urllib.parse.parse_qsl(response_data))
+ return request_token_resp
+
+ def percent_encode(self, val):
+ return urllib.parse.quote(val, safe='')
+
+ def gen_nonce(self):
+ nonce = ''.join([str(random.randint(0, 9)) for i in range(32)])
+ return nonce
+
+ def get_twitter_creds(self, toolkit_id):
+ toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
+ organisation_id = toolkit.organisation_id
+ twitter_creds = self.session.query(OauthTokens).filter(OauthTokens.toolkit_id == toolkit_id, OauthTokens.organisation_id == organisation_id).first()
+ twitter_creds = ast.literal_eval(twitter_creds.value)
+ final_creds = Creds(twitter_creds['api_key'], twitter_creds['api_key_secret'], twitter_creds['oauth_token'], twitter_creds['oauth_token_secret'])
+ return final_creds
\ No newline at end of file
diff --git a/superagi/models/oauth_tokens.py b/superagi/models/oauth_tokens.py
new file mode 100644
index 000000000..996c0dae8
--- /dev/null
+++ b/superagi/models/oauth_tokens.py
@@ -0,0 +1,53 @@
+from sqlalchemy import Column, Integer, String, Text
+from sqlalchemy.orm import Session
+
+from superagi.models.base_model import DBBaseModel
+import json
+import yaml
+
+
+
+class OauthTokens(DBBaseModel):
+ """
+ Model representing a OauthTokens.
+
+ Attributes:
+ id (Integer): The primary key of the oauth token.
+ user_id (Integer): The ID of the user associated with the Tokens.
+ toolkit_id (Integer): The ID of the toolkit associated with the Tokens.
+ key (String): The Token Key.
+ value (Text): The Token value.
+ """
+
+ __tablename__ = 'oauth_tokens'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ user_id = Column(Integer)
+ organisation_id = Column(Integer)
+ toolkit_id = Column(Integer)
+ key = Column(String)
+ value = Column(Text)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the OauthTokens object.
+
+ Returns:
+ str: String representation of the OauthTokens object.
+ """
+
+ return f"Tokens(id={self.id}, user_id={self.user_id}, organisation_id={self.organisation_id} toolkit_id={self.toolkit_id}, key={self.key}, value={self.value})"
+
+ @classmethod
+ def add_or_update(self, session: Session, toolkit_id: int, user_id: int, organisation_id: int, key: str, value: Text = None):
+ oauth_tokens = session.query(OauthTokens).filter_by(toolkit_id=toolkit_id, user_id=user_id).first()
+ if oauth_tokens:
+ # Update existing oauth tokens
+ if value is not None:
+ oauth_tokens.value = value
+ else:
+ # Create new oauth tokens
+ oauth_tokens = OauthTokens(toolkit_id=toolkit_id, user_id=user_id, organisation_id=organisation_id, key=key, value=value)
+ session.add(oauth_tokens)
+
+ session.commit()
\ No newline at end of file
diff --git a/superagi/tools/twitter/send_tweets.py b/superagi/tools/twitter/send_tweets.py
new file mode 100644
index 000000000..79d19db2b
--- /dev/null
+++ b/superagi/tools/twitter/send_tweets.py
@@ -0,0 +1,38 @@
+import os
+import json
+import base64
+import requests
+from typing import Any, Type
+from pydantic import BaseModel, Field
+from superagi.tools.base_tool import BaseTool
+from superagi.helper.twitter_tokens import TwitterTokens
+from superagi.helper.twitter_helper import TwitterHelper
+
+class SendTweetsInput(BaseModel):
+ tweet_text: str = Field(..., description="Tweet text to be posted from twitter handle, if no value is given keep the default value as 'None'")
+ is_media: bool = Field(..., description="'True' if there is any media to be posted with Tweet else 'False'.")
+ media_files: list = Field(..., description="Name of the media files to be uploaded.")
+
+class SendTweetsTool(BaseTool):
+ name: str = "Send Tweets Tool"
+ args_schema: Type[BaseModel] = SendTweetsInput
+ description: str = "Send and Schedule Tweets for your Twitter Handle"
+ agent_id: int = None
+
+ def _execute(self, is_media: bool, tweet_text: str = 'None', media_files: list = []):
+ toolkit_id = self.toolkit_config.toolkit_id
+ session = self.toolkit_config.session
+ creds = TwitterTokens(session).get_twitter_creds(toolkit_id)
+ params = {}
+ if is_media:
+ media_ids = TwitterHelper().get_media_ids(media_files, creds, self.agent_id)
+ params["media"] = {"media_ids": media_ids}
+ if tweet_text is not None:
+ params["text"] = tweet_text
+ tweet_response = TwitterHelper().send_tweets(params, creds)
+ if tweet_response.status_code == 201:
+ return "Tweet posted successfully!!"
+ else:
+ return "Error posting tweet. (Status code: {})".format(tweet_response.status_code)
+
+
\ No newline at end of file
diff --git a/superagi/tools/twitter/twitter_toolkit.py b/superagi/tools/twitter/twitter_toolkit.py
new file mode 100644
index 000000000..75c7eea95
--- /dev/null
+++ b/superagi/tools/twitter/twitter_toolkit.py
@@ -0,0 +1,15 @@
+from abc import ABC
+from superagi.tools.base_tool import BaseToolkit, BaseTool
+from typing import Type, List
+from superagi.tools.twitter.send_tweets import SendTweetsTool
+
+
+class TwitterToolkit(BaseToolkit, ABC):
+ name: str = "Twitter Toolkit"
+ description: str = "Twitter Tool kit contains all tools related to Twitter"
+
+ def get_tools(self) -> List[BaseTool]:
+ return [SendTweetsTool()]
+
+ def get_env_keys(self) -> List[str]:
+ return ["TWITTER_API_KEY", "TWITTER_API_SECRET"]
diff --git a/tests/unit_tests/helper/test_twitter_helper.py b/tests/unit_tests/helper/test_twitter_helper.py
new file mode 100644
index 000000000..6bf4e0015
--- /dev/null
+++ b/tests/unit_tests/helper/test_twitter_helper.py
@@ -0,0 +1,56 @@
+import unittest
+from unittest.mock import Mock, patch
+from requests.models import Response
+from requests_oauthlib import OAuth1Session
+from superagi.helper.twitter_helper import TwitterHelper
+
+class TestSendTweets(unittest.TestCase):
+
+ @patch.object(OAuth1Session, 'post')
+ def test_send_tweets_success(self, mock_post):
+ # Prepare test data and mocks
+ test_params = {"status": "Hello, Twitter!"}
+ test_creds = Mock()
+ test_oauth = OAuth1Session(test_creds.api_key)
+
+ # Mock successful posting
+ resp = Response()
+ resp.status_code = 200
+ mock_post.return_value = resp
+
+ # Call the method under test
+ response = TwitterHelper().send_tweets(test_params, test_creds)
+
+ # Assert the post request was called correctly
+ test_oauth.post.assert_called_once_with(
+ "https://api.twitter.com/2/tweets",
+ json=test_params)
+
+ # Assert the response is correct
+ self.assertEqual(response.status_code, 200)
+
+ @patch.object(OAuth1Session, 'post')
+ def test_send_tweets_failure(self, mock_post):
+ # Prepare test data and mocks
+ test_params = {"status": "Hello, Twitter!"}
+ test_creds = Mock()
+ test_oauth = OAuth1Session(test_creds.api_key)
+
+ # Mock unsuccessful posting
+ resp = Response()
+ resp.status_code = 400
+ mock_post.return_value = resp
+
+ # Call the method under test
+ response = TwitterHelper().send_tweets(test_params, test_creds)
+
+ # Assert the post request was called correctly
+ test_oauth.post.assert_called_once_with(
+ "https://api.twitter.com/2/tweets",
+ json=test_params)
+
+ # Assert the response is correct
+ self.assertEqual(response.status_code, 400)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_twitter_tokens.py b/tests/unit_tests/helper/test_twitter_tokens.py
new file mode 100644
index 000000000..1fc9f50bd
--- /dev/null
+++ b/tests/unit_tests/helper/test_twitter_tokens.py
@@ -0,0 +1,51 @@
+import unittest
+from unittest.mock import patch, Mock, MagicMock
+from typing import NamedTuple
+import ast
+from sqlalchemy.orm import Session
+from superagi.helper.twitter_tokens import Creds, TwitterTokens
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+import time
+import http.client
+
+
+class TestCreds(unittest.TestCase):
+ def test_init(self):
+ creds = Creds('api_key', 'api_key_secret', 'oauth_token', 'oauth_token_secret')
+ self.assertEqual(creds.api_key, 'api_key')
+ self.assertEqual(creds.api_key_secret, 'api_key_secret')
+ self.assertEqual(creds.oauth_token, 'oauth_token')
+ self.assertEqual(creds.oauth_token_secret, 'oauth_token_secret')
+
+
+class TestTwitterTokens(unittest.TestCase):
+ twitter_tokens = TwitterTokens(Session)
+ def setUp(self):
+ self.mock_session = Mock(spec=Session)
+ self.twitter_tokens = TwitterTokens(session=self.mock_session)
+
+ def test_init(self):
+ self.assertEqual(self.twitter_tokens.session, self.mock_session)
+
+ def test_percent_encode(self):
+ self.assertEqual(self.twitter_tokens.percent_encode("#"), "%23")
+
+ def test_gen_nonce(self):
+ self.assertEqual(len(self.twitter_tokens.gen_nonce()), 32)
+
+ @patch.object(time, 'time', return_value=1234567890)
+ @patch.object(http.client, 'HTTPSConnection')
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.gen_nonce', return_value=123456) # Replace '__main__' with actual module name
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.percent_encode', return_value="encoded") # Replace '__main__' with actual module name
+ def test_get_request_token(self, mock_percent_encode, mock_gen_nonce, mock_https_connection, mock_time):
+ response_mock = Mock()
+ response_mock.read.return_value = b'oauth_token=test_token&oauth_token_secret=test_secret'
+ mock_https_connection.return_value.getresponse.return_value = response_mock
+
+ api_data = {"api_key": "test_key", "api_secret": "test_secret"}
+ expected_result = {'oauth_token': 'test_token', 'oauth_token_secret': 'test_secret'}
+ self.assertEqual(self.twitter_tokens.get_request_token(api_data), expected_result)
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py
new file mode 100644
index 000000000..a50fa1045
--- /dev/null
+++ b/tests/unit_tests/tools/twitter/test_send_tweets.py
@@ -0,0 +1,51 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from superagi.tools.twitter.send_tweets import SendTweetsInput, SendTweetsTool
+
+
+class TestSendTweetsInput(unittest.TestCase):
+ def test_fields(self):
+ # Creating object
+ data = SendTweetsInput(tweet_text='Hello world', is_media=True, media_files=['image1.png', 'image2.png'])
+ # Testing object
+ self.assertEqual(data.tweet_text, 'Hello world')
+ self.assertEqual(data.is_media, True)
+ self.assertEqual(data.media_files, ['image1.png', 'image2.png'])
+
+
+class TestSendTweetsTool(unittest.TestCase):
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.get_twitter_creds', return_value={'token': '123', 'token_secret': '456'})
+ @patch('superagi.helper.twitter_helper.TwitterHelper.get_media_ids', return_value=[789])
+ @patch('superagi.helper.twitter_helper.TwitterHelper.send_tweets')
+ def test_execute(self, mock_send_tweets, mock_get_media_ids, mock_get_twitter_creds):
+ # Mock the response from 'send_tweets'
+ responseMock = MagicMock()
+ responseMock.status_code = 201
+ mock_send_tweets.return_value = responseMock
+
+ # Creating SendTweetsTool object
+ obj = SendTweetsTool()
+ obj.toolkit_config = MagicMock()
+ obj.toolkit_config.toolkit_id = 1
+ obj.toolkit_config.session = MagicMock()
+ obj.agent_id = 99
+
+ # Testing when 'is_media' is True, 'tweet_text' is 'None' and 'media_files' is an empty list
+ self.assertEqual(obj._execute(True), "Tweet posted successfully!!")
+ mock_get_twitter_creds.assert_called_once_with(1)
+ mock_get_media_ids.assert_called_once_with([], {'token': '123', 'token_secret': '456'}, 99)
+ mock_send_tweets.assert_called_once_with({'media': {'media_ids': [789]}, 'text': 'None'}, {'token': '123', 'token_secret': '456'})
+
+ # Testing when 'is_media' is False, 'tweet_text' is 'Hello world' and 'media_files' is a list with elements
+ mock_get_twitter_creds.reset_mock()
+ mock_get_media_ids.reset_mock()
+ mock_send_tweets.reset_mock()
+ responseMock.status_code = 400
+ self.assertEqual(obj._execute(False, 'Hello world', ['image1.png']), "Error posting tweet. (Status code: 400)")
+ mock_get_twitter_creds.assert_called_once_with(1)
+ mock_get_media_ids.assert_not_called()
+ mock_send_tweets.assert_called_once_with({'text': 'Hello world'}, {'token': '123', 'token_secret': '456'})
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/workspace/input/testing.txt b/workspace/input/testing.txt
new file mode 100644
index 000000000..fe857725e
--- /dev/null
+++ b/workspace/input/testing.txt
@@ -0,0 +1 @@
+"Hello world"
diff --git a/workspace/output/testing.txt b/workspace/output/testing.txt
new file mode 100644
index 000000000..06ae699f2
--- /dev/null
+++ b/workspace/output/testing.txt
@@ -0,0 +1 @@
+"Hello World"