diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js index 51cf7897a..a724c5ecf 100644 --- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js +++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js @@ -1,7 +1,7 @@ import React, {useEffect, useState} from 'react'; import Image from 'next/image'; import {ToastContainer, toast} from 'react-toastify'; -import {updateToolConfig, getToolConfig, authenticateGoogleCred} from "@/pages/api/DashboardService"; +import {updateToolConfig, getToolConfig, authenticateGoogleCred, authenticateTwitterCred} from "@/pages/api/DashboardService"; import styles from './Tool.module.css'; import {setLocalStorageValue, setLocalStorageArray} from "@/utils/utils"; @@ -25,6 +25,13 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){ window.location.href = `https://accounts.google.com/o/oauth2/v2/auth?client_id=${client_id}&redirect_uri=${redirect_uri}&access_type=offline&response_type=code&scope=${scope}`; } + function getTwitterToken(oauth_data){ + const oauth_token = oauth_data.oauth_token + const oauth_token_secret = oauth_data.oauth_token_secret + const authUrl = `https://api.twitter.com/oauth/authenticate?oauth_token=${oauth_token}` + window.location.href = authUrl + } + useEffect(() => { if(toolkitDetails !== null) { if (toolkitDetails.tools) { @@ -37,7 +44,7 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){ const apiConfigs = response.data || []; setApiConfigs(localStoredConfigs ? JSON.parse(localStoredConfigs) : apiConfigs); }) - .catch((error) => { + .catch((errPor) => { console.log('Error fetching API data:', error); }) .finally(() => { @@ -72,6 +79,17 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){ }); }; + const handleTwitterAuthClick = async () => { + authenticateTwitterCred(toolkitDetails.id) + .then((response) => { + getTwitterToken(response.data); + localStorage.setItem("twitter_toolkit_id", toolkitDetails.id) + }) + .catch((error) => { + console.error('Error fetching data: ', error); + }); + }; + useEffect(() => { const active_tab = localStorage.getItem('toolkit_tab_' + String(internalId)); if(active_tab) { @@ -124,6 +142,7 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
{toolkitDetails.name === 'Google Calendar Toolkit' && } + {toolkitDetails.name === 'Twitter Toolkit' && }
diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js index 3698c6a06..e1e841213 100644 --- a/gui/pages/Dashboard/Content.js +++ b/gui/pages/Dashboard/Content.js @@ -7,9 +7,13 @@ import Settings from "./Settings/Settings"; import styles from './Dashboard.module.css'; import Image from "next/image"; import { EventBus } from "@/utils/eventBus"; -import {getAgents, getToolKit, getLastActiveAgent} from "@/pages/api/DashboardService"; +import {getAgents, getToolKit, getLastActiveAgent, sendTwitterCreds} from "@/pages/api/DashboardService"; import Market from "../Content/Marketplace/Market"; import AgentTemplatesList from '../Content/Agents/AgentTemplatesList'; +import { useRouter } from 'next/router'; +import querystring from 'querystring'; +import { userInfo } from 'os'; +import { parse } from 'path'; import AddTool from "@/pages/Content/Toolkits/AddTool"; import {createInternalId, removeInternalId} from "@/utils/utils"; @@ -20,6 +24,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat const [toolkits, setToolkits] = useState(null); const tabContainerRef = useRef(null); const [toolkitDetails, setToolkitDetails] = useState({}) + const router = useRouter(); function fetchAgents() { getAgents(selectedProjectId) @@ -140,6 +145,21 @@ export default function Content({env, selectedView, selectedProjectId, organisat } } } + const queryParams = router.asPath.split('?')[1]; + const parsedParams = querystring.parse(queryParams); + parsedParams["toolkit_id"] = toolkitDetails.toolkit_id; + if (window.location.href.indexOf("twitter_creds") > -1){ + const toolkit_id = localStorage.getItem("twitter_toolkit_id") || null; + parsedParams["toolkit_id"] = toolkit_id; + const params = JSON.stringify(parsedParams) + sendTwitterCreds(params) + .then((response) => { + console.log("Authentication completed successfully"); + }) + .catch((error) => { + console.error("Error fetching data: ",error); + }) + }; }, [selectedTab]); useEffect(() => { @@ -171,7 +191,6 @@ export default function Content({env, selectedView, selectedProjectId, organisat EventBus.off('openNewTab', openNewTab); EventBus.off('reFetchAgents', fetchAgents); EventBus.off('removeTab', removeTab); - EventBus.off('openToolkitTab', openToolkitTab); }; }); diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index aa86d1f2e..dc3ed4530 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -132,6 +132,14 @@ export const authenticateGoogleCred = (toolKitId) => { return api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`); } +export const authenticateTwitterCred = (toolKitId) => { + return api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`); +} + +export const sendTwitterCreds = (twitter_creds) => { + return api.post(`/twitter/send_twitter_creds/${twitter_creds}`); +} + export const fetchToolTemplateList = () => { return api.get(`/toolkits/get/list?page=0`); } diff --git a/main.py b/main.py index 222e9dd4c..7d723fb2b 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,11 @@ from sqlalchemy.orm import sessionmaker import superagi +import urllib.parse +import json +import http.client as http_client +from superagi.helper.twitter_tokens import TwitterTokens +from datetime import datetime, timedelta from superagi.agent.agent_prompt_builder import AgentPromptBuilder from superagi.config.config import get_config from superagi.controllers.agent import router as agent_router @@ -28,6 +33,7 @@ from superagi.controllers.config import router as config_router from superagi.controllers.organisation import router as organisation_router from superagi.controllers.project import router as project_router +from superagi.controllers.twitter_oauth import router as twitter_oauth_router from superagi.controllers.resources import router as resources_router from superagi.controllers.tool import router as tool_router from superagi.controllers.tool_config import router as tool_config_router @@ -36,11 +42,13 @@ from superagi.helper.tool_helper import register_toolkits from superagi.lib.logger import logger from superagi.llms.openai import OpenAi +from superagi.helper.auth import get_current_user from superagi.models.agent_workflow import AgentWorkflow from superagi.models.agent_workflow_step import AgentWorkflowStep from superagi.models.organisation import Organisation from superagi.models.tool_config import ToolConfig from superagi.models.toolkit import Toolkit +from superagi.models.oauth_tokens import OauthTokens from superagi.models.types.login_request import LoginRequest from superagi.models.user import User @@ -97,6 +105,7 @@ app.include_router(config_router, prefix="/configs") app.include_router(agent_template_router, prefix="/agent_templates") app.include_router(agent_workflow_router, prefix="/agent_workflows") +app.include_router(twitter_oauth_router, prefix="/twitter") # in production you can use Settings management @@ -320,7 +329,6 @@ async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depe frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000") return RedirectResponse(frontend_url) - @app.get('/github-login') def github_login(): """GitHub login""" @@ -411,7 +419,6 @@ def get_google_calendar_tool_configs(toolkit_id: int): "client_id": google_calendar_config.value } - @app.get("/validate-open-ai-key/{open_ai_key}") async def root(open_ai_key: str, Authorize: AuthJWT = Depends()): """API to validate Open AI Key""" diff --git a/migrations/versions/c5c19944c90c_create_oauth_tokens.py b/migrations/versions/c5c19944c90c_create_oauth_tokens.py new file mode 100644 index 000000000..8986afcdc --- /dev/null +++ b/migrations/versions/c5c19944c90c_create_oauth_tokens.py @@ -0,0 +1,48 @@ +"""Create Oauth Tokens + +Revision ID: c5c19944c90c +Revises: 7a3e336c0fba +Create Date: 2023-06-30 07:26:29.180784 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'c5c19944c90c' +down_revision = '7a3e336c0fba' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth_tokens', + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('organisation_id', sa.Integer(), nullable=True), + sa.Column('toolkit_id', sa.Integer(), nullable=True), + sa.Column('key', sa.String(), nullable=True), + sa.Column('value', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.drop_index('ix_agent_execution_permissions_agent_execution_id', table_name='agent_execution_permissions') + op.drop_index('ix_atc_agnt_template_id_key', table_name='agent_template_configs') + op.drop_index('ix_agt_agnt_name', table_name='agent_templates') + op.drop_index('ix_agt_agnt_organisation_id', table_name='agent_templates') + op.drop_index('ix_agt_agnt_workflow_id', table_name='agent_templates') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index('ix_agt_agnt_workflow_id', 'agent_templates', ['agent_workflow_id'], unique=False) + op.create_index('ix_agt_agnt_organisation_id', 'agent_templates', ['organisation_id'], unique=False) + op.create_index('ix_agt_agnt_name', 'agent_templates', ['name'], unique=False) + op.create_index('ix_atc_agnt_template_id_key', 'agent_template_configs', ['agent_template_id', 'key'], unique=False) + op.create_index('ix_agent_execution_permissions_agent_execution_id', 'agent_execution_permissions', ['agent_execution_id'], unique=False) + op.drop_table('oauth_tokens') + # ### end Alembic commands ### diff --git a/superagi/controllers/twitter_oauth.py b/superagi/controllers/twitter_oauth.py new file mode 100644 index 000000000..b79b7be66 --- /dev/null +++ b/superagi/controllers/twitter_oauth.py @@ -0,0 +1,70 @@ +from fastapi import Depends, Query +from fastapi import APIRouter +from fastapi.responses import RedirectResponse +from fastapi_jwt_auth import AuthJWT +from fastapi_sqlalchemy import db +from sqlalchemy.orm import sessionmaker + +import superagi +import json +from superagi.models.db import connect_db +import http.client as http_client +from superagi.helper.twitter_tokens import TwitterTokens +from superagi.helper.auth import get_current_user +from superagi.models.tool_config import ToolConfig +from superagi.models.toolkit import Toolkit +from superagi.models.oauth_tokens import OauthTokens + +router = APIRouter() + +@router.get('/oauth-tokens') +async def twitter_oauth(oauth_token: str = Query(...),oauth_verifier: str = Query(...), Authorize: AuthJWT = Depends()): + print("///////////////////////////") + print(oauth_token) + token_uri = f'https://api.twitter.com/oauth/access_token?oauth_verifier={oauth_verifier}&oauth_token={oauth_token}' + conn = http_client.HTTPSConnection("api.twitter.com") + conn.request("POST", token_uri, "") + res = conn.getresponse() + response_data = res.read().decode('utf-8') + frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000") + redirect_url_success = f"{frontend_url}/twitter_creds/?{response_data}" + return RedirectResponse(url=redirect_url_success) + +@router.post("/send_twitter_creds/{twitter_creds}") +def send_twitter_tool_configs(twitter_creds: str, Authorize: AuthJWT = Depends()): + engine = connect_db() + Session = sessionmaker(bind=engine) + session = Session() + current_user = get_current_user() + user_id = current_user.id + credentials = json.loads(twitter_creds) + credentials["user_id"] = user_id + toolkit = db.session.query(Toolkit).filter(Toolkit.id == credentials["toolkit_id"]).first() + api_key = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_KEY", ToolConfig.toolkit_id == credentials["toolkit_id"]).first() + api_key_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_SECRET", ToolConfig.toolkit_id == credentials["toolkit_id"]).first() + final_creds = { + "api_key": api_key.value, + "api_key_secret": api_key_secret.value, + "oauth_token": credentials["oauth_token"], + "oauth_token_secret": credentials["oauth_token_secret"] + } + tokens = OauthTokens().add_or_update(session, credentials["toolkit_id"], user_id, toolkit.organisation_id, "TWITTER_OAUTH_TOKENS", str(final_creds)) + if tokens: + success = True + else: + success = False + return success + +@router.get("/get_twitter_creds/toolkit_id/{toolkit_id}") +def get_twitter_tool_configs(toolkit_id: int): + engine = connect_db() + Session = sessionmaker(bind=engine) + session = Session() + twitter_config_key = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_KEY").first() + twitter_config_secret = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_SECRET").first() + api_data = { + "api_key": twitter_config_key.value, + "api_secret": twitter_config_secret.value + } + response = TwitterTokens(session).get_request_token(api_data) + return response \ No newline at end of file diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py index 5a80677e6..a185ece83 100644 --- a/superagi/helper/auth.py +++ b/superagi/helper/auth.py @@ -33,6 +33,13 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)): Returns: Organisation: Instance of Organisation class to which the authenticated user belongs. """ + user = get_current_user() + if user is None: + raise HTTPException(status_code=401, detail="Unauthenticated") + organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first() + return organisation + +def get_current_user(Authorize: AuthJWT = Depends(check_auth)): env = get_config("ENV", "DEV") if env == "DEV": @@ -43,7 +50,4 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)): # Query the User table to find the user by their email user = db.session.query(User).filter(User.email == email).first() - if user is None: - raise HTTPException(status_code=401, detail="Unauthenticated") - organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first() - return organisation \ No newline at end of file + return user \ No newline at end of file diff --git a/superagi/helper/twitter_helper.py b/superagi/helper/twitter_helper.py new file mode 100644 index 000000000..e47c3d60d --- /dev/null +++ b/superagi/helper/twitter_helper.py @@ -0,0 +1,42 @@ +import os +import json +import base64 +import requests +from requests_oauthlib import OAuth1 +from requests_oauthlib import OAuth1Session +from superagi.helper.resource_helper import ResourceHelper + +class TwitterHelper: + + def get_media_ids(self, media_files, creds, agent_id): + media_ids = [] + oauth = OAuth1(creds.api_key, + client_secret=creds.api_key_secret, + resource_owner_key=creds.oauth_token, + resource_owner_secret=creds.oauth_token_secret) + for file in media_files: + file_path = self.get_file_path(file, agent_id) + image_data = open(file_path, 'rb').read() + b64_image = base64.b64encode(image_data) + upload_endpoint = 'https://upload.twitter.com/1.1/media/upload.json' + headers = {'Authorization': 'application/octet-stream'} + response = requests.post(upload_endpoint, headers=headers, + data={'media_data': b64_image}, + auth=oauth) + ids = json.loads(response.text)['media_id'] + media_ids.append(str(ids)) + return media_ids + + def get_file_path(self, file_name, agent_id): + final_path = ResourceHelper().get_agent_resource_path(file_name, agent_id) + return final_path + + def send_tweets(self, params, creds): + tweet_endpoint = "https://api.twitter.com/2/tweets" + oauth = OAuth1Session(creds.api_key, + client_secret=creds.api_key_secret, + resource_owner_key=creds.oauth_token, + resource_owner_secret=creds.oauth_token_secret) + + response = oauth.post(tweet_endpoint,json=params) + return response diff --git a/superagi/helper/twitter_tokens.py b/superagi/helper/twitter_tokens.py new file mode 100644 index 000000000..10b36d59a --- /dev/null +++ b/superagi/helper/twitter_tokens.py @@ -0,0 +1,77 @@ +import hmac +import time +import random +import base64 +import hashlib +import urllib.parse +import ast +import http.client as http_client +from sqlalchemy.orm import Session +from superagi.models.toolkit import Toolkit +from superagi.models.oauth_tokens import OauthTokens + +class Creds: + + def __init__(self,api_key, api_key_secret, oauth_token, oauth_token_secret): + self.api_key = api_key + self.api_key_secret = api_key_secret + self.oauth_token = oauth_token + self.oauth_token_secret = oauth_token_secret + +class TwitterTokens: + + def __init__(self, session: Session): + self.session = session + + def get_request_token(self,api_data): + api_key = api_data["api_key"] + api_secret_key = api_data["api_secret"] + http_method = 'POST' + base_url = 'https://api.twitter.com/oauth/request_token' + + params = { + 'oauth_callback': 'http://localhost:3000/api/twitter/oauth-tokens', + 'oauth_consumer_key': api_key, + 'oauth_nonce': self.gen_nonce(), + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': int(time.time()), + 'oauth_version': '1.0' + } + + params_sorted = sorted(params.items()) + params_qs = '&'.join([f'{k}={self.percent_encode(str(v))}' for k, v in params_sorted]) + + base_string = f'{http_method}&{self.percent_encode(base_url)}&{self.percent_encode(params_qs)}' + + signing_key = f'{self.percent_encode(api_secret_key)}&' + signature = hmac.new(signing_key.encode(), base_string.encode(), hashlib.sha1) + params['oauth_signature'] = base64.b64encode(signature.digest()).decode() + + auth_header = 'OAuth ' + ', '.join([f'{k}="{self.percent_encode(str(v))}"' for k, v in params.items()]) + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Authorization': auth_header + } + conn = http_client.HTTPSConnection("api.twitter.com") + conn.request("POST", "/oauth/request_token", "", headers) + res = conn.getresponse() + response_data = res.read().decode('utf-8') + conn.close() + request_token_resp = dict(urllib.parse.parse_qsl(response_data)) + return request_token_resp + + def percent_encode(self, val): + return urllib.parse.quote(val, safe='') + + def gen_nonce(self): + nonce = ''.join([str(random.randint(0, 9)) for i in range(32)]) + return nonce + + def get_twitter_creds(self, toolkit_id): + toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first() + organisation_id = toolkit.organisation_id + twitter_creds = self.session.query(OauthTokens).filter(OauthTokens.toolkit_id == toolkit_id, OauthTokens.organisation_id == organisation_id).first() + twitter_creds = ast.literal_eval(twitter_creds.value) + final_creds = Creds(twitter_creds['api_key'], twitter_creds['api_key_secret'], twitter_creds['oauth_token'], twitter_creds['oauth_token_secret']) + return final_creds \ No newline at end of file diff --git a/superagi/models/oauth_tokens.py b/superagi/models/oauth_tokens.py new file mode 100644 index 000000000..996c0dae8 --- /dev/null +++ b/superagi/models/oauth_tokens.py @@ -0,0 +1,53 @@ +from sqlalchemy import Column, Integer, String, Text +from sqlalchemy.orm import Session + +from superagi.models.base_model import DBBaseModel +import json +import yaml + + + +class OauthTokens(DBBaseModel): + """ + Model representing a OauthTokens. + + Attributes: + id (Integer): The primary key of the oauth token. + user_id (Integer): The ID of the user associated with the Tokens. + toolkit_id (Integer): The ID of the toolkit associated with the Tokens. + key (String): The Token Key. + value (Text): The Token value. + """ + + __tablename__ = 'oauth_tokens' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer) + organisation_id = Column(Integer) + toolkit_id = Column(Integer) + key = Column(String) + value = Column(Text) + + def __repr__(self): + """ + Returns a string representation of the OauthTokens object. + + Returns: + str: String representation of the OauthTokens object. + """ + + return f"Tokens(id={self.id}, user_id={self.user_id}, organisation_id={self.organisation_id} toolkit_id={self.toolkit_id}, key={self.key}, value={self.value})" + + @classmethod + def add_or_update(self, session: Session, toolkit_id: int, user_id: int, organisation_id: int, key: str, value: Text = None): + oauth_tokens = session.query(OauthTokens).filter_by(toolkit_id=toolkit_id, user_id=user_id).first() + if oauth_tokens: + # Update existing oauth tokens + if value is not None: + oauth_tokens.value = value + else: + # Create new oauth tokens + oauth_tokens = OauthTokens(toolkit_id=toolkit_id, user_id=user_id, organisation_id=organisation_id, key=key, value=value) + session.add(oauth_tokens) + + session.commit() \ No newline at end of file diff --git a/superagi/tools/twitter/send_tweets.py b/superagi/tools/twitter/send_tweets.py new file mode 100644 index 000000000..79d19db2b --- /dev/null +++ b/superagi/tools/twitter/send_tweets.py @@ -0,0 +1,38 @@ +import os +import json +import base64 +import requests +from typing import Any, Type +from pydantic import BaseModel, Field +from superagi.tools.base_tool import BaseTool +from superagi.helper.twitter_tokens import TwitterTokens +from superagi.helper.twitter_helper import TwitterHelper + +class SendTweetsInput(BaseModel): + tweet_text: str = Field(..., description="Tweet text to be posted from twitter handle, if no value is given keep the default value as 'None'") + is_media: bool = Field(..., description="'True' if there is any media to be posted with Tweet else 'False'.") + media_files: list = Field(..., description="Name of the media files to be uploaded.") + +class SendTweetsTool(BaseTool): + name: str = "Send Tweets Tool" + args_schema: Type[BaseModel] = SendTweetsInput + description: str = "Send and Schedule Tweets for your Twitter Handle" + agent_id: int = None + + def _execute(self, is_media: bool, tweet_text: str = 'None', media_files: list = []): + toolkit_id = self.toolkit_config.toolkit_id + session = self.toolkit_config.session + creds = TwitterTokens(session).get_twitter_creds(toolkit_id) + params = {} + if is_media: + media_ids = TwitterHelper().get_media_ids(media_files, creds, self.agent_id) + params["media"] = {"media_ids": media_ids} + if tweet_text is not None: + params["text"] = tweet_text + tweet_response = TwitterHelper().send_tweets(params, creds) + if tweet_response.status_code == 201: + return "Tweet posted successfully!!" + else: + return "Error posting tweet. (Status code: {})".format(tweet_response.status_code) + + \ No newline at end of file diff --git a/superagi/tools/twitter/twitter_toolkit.py b/superagi/tools/twitter/twitter_toolkit.py new file mode 100644 index 000000000..75c7eea95 --- /dev/null +++ b/superagi/tools/twitter/twitter_toolkit.py @@ -0,0 +1,15 @@ +from abc import ABC +from superagi.tools.base_tool import BaseToolkit, BaseTool +from typing import Type, List +from superagi.tools.twitter.send_tweets import SendTweetsTool + + +class TwitterToolkit(BaseToolkit, ABC): + name: str = "Twitter Toolkit" + description: str = "Twitter Tool kit contains all tools related to Twitter" + + def get_tools(self) -> List[BaseTool]: + return [SendTweetsTool()] + + def get_env_keys(self) -> List[str]: + return ["TWITTER_API_KEY", "TWITTER_API_SECRET"] diff --git a/tests/unit_tests/helper/test_twitter_helper.py b/tests/unit_tests/helper/test_twitter_helper.py new file mode 100644 index 000000000..6bf4e0015 --- /dev/null +++ b/tests/unit_tests/helper/test_twitter_helper.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import Mock, patch +from requests.models import Response +from requests_oauthlib import OAuth1Session +from superagi.helper.twitter_helper import TwitterHelper + +class TestSendTweets(unittest.TestCase): + + @patch.object(OAuth1Session, 'post') + def test_send_tweets_success(self, mock_post): + # Prepare test data and mocks + test_params = {"status": "Hello, Twitter!"} + test_creds = Mock() + test_oauth = OAuth1Session(test_creds.api_key) + + # Mock successful posting + resp = Response() + resp.status_code = 200 + mock_post.return_value = resp + + # Call the method under test + response = TwitterHelper().send_tweets(test_params, test_creds) + + # Assert the post request was called correctly + test_oauth.post.assert_called_once_with( + "https://api.twitter.com/2/tweets", + json=test_params) + + # Assert the response is correct + self.assertEqual(response.status_code, 200) + + @patch.object(OAuth1Session, 'post') + def test_send_tweets_failure(self, mock_post): + # Prepare test data and mocks + test_params = {"status": "Hello, Twitter!"} + test_creds = Mock() + test_oauth = OAuth1Session(test_creds.api_key) + + # Mock unsuccessful posting + resp = Response() + resp.status_code = 400 + mock_post.return_value = resp + + # Call the method under test + response = TwitterHelper().send_tweets(test_params, test_creds) + + # Assert the post request was called correctly + test_oauth.post.assert_called_once_with( + "https://api.twitter.com/2/tweets", + json=test_params) + + # Assert the response is correct + self.assertEqual(response.status_code, 400) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit_tests/helper/test_twitter_tokens.py b/tests/unit_tests/helper/test_twitter_tokens.py new file mode 100644 index 000000000..1fc9f50bd --- /dev/null +++ b/tests/unit_tests/helper/test_twitter_tokens.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import patch, Mock, MagicMock +from typing import NamedTuple +import ast +from sqlalchemy.orm import Session +from superagi.helper.twitter_tokens import Creds, TwitterTokens +from superagi.models.toolkit import Toolkit +from superagi.models.oauth_tokens import OauthTokens +import time +import http.client + + +class TestCreds(unittest.TestCase): + def test_init(self): + creds = Creds('api_key', 'api_key_secret', 'oauth_token', 'oauth_token_secret') + self.assertEqual(creds.api_key, 'api_key') + self.assertEqual(creds.api_key_secret, 'api_key_secret') + self.assertEqual(creds.oauth_token, 'oauth_token') + self.assertEqual(creds.oauth_token_secret, 'oauth_token_secret') + + +class TestTwitterTokens(unittest.TestCase): + twitter_tokens = TwitterTokens(Session) + def setUp(self): + self.mock_session = Mock(spec=Session) + self.twitter_tokens = TwitterTokens(session=self.mock_session) + + def test_init(self): + self.assertEqual(self.twitter_tokens.session, self.mock_session) + + def test_percent_encode(self): + self.assertEqual(self.twitter_tokens.percent_encode("#"), "%23") + + def test_gen_nonce(self): + self.assertEqual(len(self.twitter_tokens.gen_nonce()), 32) + + @patch.object(time, 'time', return_value=1234567890) + @patch.object(http.client, 'HTTPSConnection') + @patch('superagi.helper.twitter_tokens.TwitterTokens.gen_nonce', return_value=123456) # Replace '__main__' with actual module name + @patch('superagi.helper.twitter_tokens.TwitterTokens.percent_encode', return_value="encoded") # Replace '__main__' with actual module name + def test_get_request_token(self, mock_percent_encode, mock_gen_nonce, mock_https_connection, mock_time): + response_mock = Mock() + response_mock.read.return_value = b'oauth_token=test_token&oauth_token_secret=test_secret' + mock_https_connection.return_value.getresponse.return_value = response_mock + + api_data = {"api_key": "test_key", "api_secret": "test_secret"} + expected_result = {'oauth_token': 'test_token', 'oauth_token_secret': 'test_secret'} + self.assertEqual(self.twitter_tokens.get_request_token(api_data), expected_result) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py new file mode 100644 index 000000000..a50fa1045 --- /dev/null +++ b/tests/unit_tests/tools/twitter/test_send_tweets.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import MagicMock, patch +from superagi.tools.twitter.send_tweets import SendTweetsInput, SendTweetsTool + + +class TestSendTweetsInput(unittest.TestCase): + def test_fields(self): + # Creating object + data = SendTweetsInput(tweet_text='Hello world', is_media=True, media_files=['image1.png', 'image2.png']) + # Testing object + self.assertEqual(data.tweet_text, 'Hello world') + self.assertEqual(data.is_media, True) + self.assertEqual(data.media_files, ['image1.png', 'image2.png']) + + +class TestSendTweetsTool(unittest.TestCase): + @patch('superagi.helper.twitter_tokens.TwitterTokens.get_twitter_creds', return_value={'token': '123', 'token_secret': '456'}) + @patch('superagi.helper.twitter_helper.TwitterHelper.get_media_ids', return_value=[789]) + @patch('superagi.helper.twitter_helper.TwitterHelper.send_tweets') + def test_execute(self, mock_send_tweets, mock_get_media_ids, mock_get_twitter_creds): + # Mock the response from 'send_tweets' + responseMock = MagicMock() + responseMock.status_code = 201 + mock_send_tweets.return_value = responseMock + + # Creating SendTweetsTool object + obj = SendTweetsTool() + obj.toolkit_config = MagicMock() + obj.toolkit_config.toolkit_id = 1 + obj.toolkit_config.session = MagicMock() + obj.agent_id = 99 + + # Testing when 'is_media' is True, 'tweet_text' is 'None' and 'media_files' is an empty list + self.assertEqual(obj._execute(True), "Tweet posted successfully!!") + mock_get_twitter_creds.assert_called_once_with(1) + mock_get_media_ids.assert_called_once_with([], {'token': '123', 'token_secret': '456'}, 99) + mock_send_tweets.assert_called_once_with({'media': {'media_ids': [789]}, 'text': 'None'}, {'token': '123', 'token_secret': '456'}) + + # Testing when 'is_media' is False, 'tweet_text' is 'Hello world' and 'media_files' is a list with elements + mock_get_twitter_creds.reset_mock() + mock_get_media_ids.reset_mock() + mock_send_tweets.reset_mock() + responseMock.status_code = 400 + self.assertEqual(obj._execute(False, 'Hello world', ['image1.png']), "Error posting tweet. (Status code: 400)") + mock_get_twitter_creds.assert_called_once_with(1) + mock_get_media_ids.assert_not_called() + mock_send_tweets.assert_called_once_with({'text': 'Hello world'}, {'token': '123', 'token_secret': '456'}) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/workspace/input/testing.txt b/workspace/input/testing.txt new file mode 100644 index 000000000..fe857725e --- /dev/null +++ b/workspace/input/testing.txt @@ -0,0 +1 @@ +"Hello world" diff --git a/workspace/output/testing.txt b/workspace/output/testing.txt new file mode 100644 index 000000000..06ae699f2 --- /dev/null +++ b/workspace/output/testing.txt @@ -0,0 +1 @@ +"Hello World"