forked from BadDownload/vistell-discord-bot-with-gpt-vision
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
236 lines (193 loc) · 9.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
from dotenv import load_dotenv
import discord
import asyncio
import logging
import base64
import requests
import aiohttp
from PIL import Image
from io import BytesIO
from gradio_client import Client, file
# Load environment variables from .env file
load_dotenv()
# Discord Bot Token
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
# OpenAI API Key
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# Base URL of the OpenAI-compatible API. When unset, fall back to the official
# OpenAI endpoint instead of building the URL "None/v1/chat/completions"; a
# trailing slash is stripped so the join below never produces "//v1".
OPENAI_BASE_URL = (os.getenv('OPENAI_BASE_URL') or 'https://api.openai.com').rstrip('/')
# Gradio API URL
GRADIO_API_URL = os.getenv('GRADIO_API_URL')
vision_model_url = f"{OPENAI_BASE_URL}/v1/chat/completions"
# Comma-separated channel IDs the bot listens in. Blank entries (e.g. from a
# trailing comma) are ignored; an empty/unset value yields None, which means
# "listen in every channel".
_raw_channel_ids = os.getenv('CHANNEL_IDS', '')
CHANNEL_IDS = {int(part) for part in _raw_channel_ids.split(',') if part.strip()} or None
# Starting message for image analysis
STARTING_MESSAGE = os.getenv('STARTING_MESSAGE', "What’s in this image? If the image is mostly text, please provide the full text.")
# Max tokens amount for OpenAI ChatCompletion
MAX_TOKENS = int(os.getenv('MAX_TOKENS', 300))
# Message prefix
MESSAGE_PREFIX = os.getenv('MESSAGE_PREFIX', "Image Description:")
# Flag to determine if the bot should reply to image links
# NOTE(review): not consulted anywhere in this file yet.
REPLY_TO_LINKS = os.getenv('REPLY_TO_LINKS', 'true').lower() == 'true'
# Only image URLs hosted exactly on this domain are fetched/analyzed.
allowed_domain = "cdn.discordapp.com"
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Openai compatible vision api url: %s", vision_model_url)
# Initialize Discord bot with intents for messages and message content
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
bot = discord.Client(intents=intents)
async def describe_image_with_gradio(image_url, threshold=0.2):
    """Get a tag string for an image from the Gradio tagging API.

    Args:
        image_url: URL of the image. It must be an ``https`` URL whose host is
            exactly ``allowed_domain``; anything else is rejected.
        threshold: Confidence threshold forwarded to the Gradio ``/predict``
            endpoint (default 0.2, the previously hard-coded value).

    Returns:
        A one-element list holding the first element of the Gradio result
        (presumably the tag string; confirm against the Space's API). On any
        failure, a one-element list with an error message; this never raises.
    """
    from urllib.parse import urlparse

    try:
        logger.info("Sending request to the Gradio API for image analysis...")
        # Compare the parsed host exactly. A raw prefix check would also
        # accept hosts such as "cdn.discordapp.com.attacker.example".
        parsed_url = urlparse(image_url)
        if parsed_url.scheme != "https" or parsed_url.hostname != allowed_domain:
            raise ValueError("Invalid image URL domain")

        def _predict():
            # gradio_client is synchronous, and constructing the Client already
            # talks to the server, so the whole exchange runs off the event loop.
            client = Client(GRADIO_API_URL)
            return client.predict(
                image=file(image_url),
                threshold=threshold,
                api_name="/predict",
            )

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _predict)
        # Assuming the first element is the "tag string"
        tag_string = result[0]
        return [tag_string]
    except Exception as e:
        logger.error(f"Error analyzing image with Gradio API: {e}")
        return ["Error analyzing image with Gradio API."]
def _encode_png_base64(image_data):
    """Re-encode raw image bytes as PNG and return them base64-encoded as str.

    Images in a mode PNG cannot store (e.g. CMYK JPEGs) are converted to
    RGB/RGBA first; otherwise ``Image.save(..., format="PNG")`` raises.
    """
    image = Image.open(BytesIO(image_data))
    if image.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
        image = image.convert("RGBA" if "A" in image.getbands() else "RGB")
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")


def _split_into_chunks(text, chunk_length=1800):
    """Split *text* into consecutive pieces of at most *chunk_length* characters."""
    return [text[i:i + chunk_length] for i in range(0, len(text), chunk_length)]


async def describe_image_with_openai(image_url, message_content):
    """Ask the OpenAI-compatible vision endpoint to describe an image.

    The image is downloaded, re-encoded as PNG, embedded as a base64 data URL
    and posted to ``vision_model_url`` together with a text prompt.

    Args:
        image_url: URL of the image. It must be an ``https`` URL whose host is
            exactly ``allowed_domain``; anything else is rejected.
        message_content: Text of the Discord message carrying the image. When
            non-empty, it is sent as a question about the image; when empty,
            ``STARTING_MESSAGE`` is used as the prompt.

    Returns:
        A list of strings of at most 1800 characters each (kept below
        Discord's message limit). On failure, a one-element list with an
        error message; this never raises.
    """
    from urllib.parse import urlparse

    if message_content != "":
        image_prompt = f"Please answer this question about the image. Only output raw information. Follow the question exactly.\nUser question: {message_content}"
        logger.info(f"Custom message: {image_prompt}")
    else:
        image_prompt = STARTING_MESSAGE
    try:
        logger.info("Sending request to the model for image analysis...")
        # Compare the parsed host exactly. A raw prefix check would also
        # accept hosts such as "cdn.discordapp.com.attacker.example".
        parsed_url = urlparse(image_url)
        if parsed_url.scheme != "https" or parsed_url.hostname != allowed_domain:
            raise ValueError("Invalid image URL domain")

        async with aiohttp.ClientSession() as session:
            # Download with aiohttp rather than the blocking `requests` library
            # so the Discord event loop is not stalled during the transfer.
            async with session.get(image_url) as image_response:
                image_response.raise_for_status()
                image_data = await image_response.read()
            base64_data = _encode_png_base64(image_data)

            image_part = {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{base64_data}"},
            }
            text_part = {"type": "text", "text": image_prompt}
            # A user question is placed after the image; the default prompt before it.
            content = [image_part, text_part] if message_content != "" else [text_part, image_part]

            headers = {
                "Authorization": f"Bearer {OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            payload = {
                "model": "gpt-4-vision-preview",
                "messages": [{"role": "user", "content": content}],
                "max_tokens": MAX_TOKENS,
            }
            async with session.post(vision_model_url, json=payload, headers=headers) as response:
                status = response.status
                data = await response.json()
        logger.info("Received response from the model.")

        choices = data.get("choices") if isinstance(data, dict) else None
        if choices:
            # `content` may be null in some API responses; treat it as empty.
            first_choice_text = (choices[0]["message"]["content"] or "").strip()
            if first_choice_text:
                return _split_into_chunks(first_choice_text)
        logger.warning("Model returned no usable description (HTTP %s): %s", status, data)
        return ["Failed to obtain a description from the model."]
    except Exception as e:
        logger.error(f"Error analyzing image with model: {e}")
        return ["Error analyzing image with model."]
@bot.event
async def on_ready():
    """Once connected, show a "watching" status and log the bot's identity."""
    watching_status = discord.Activity(
        type=discord.ActivityType.watching,
        name='Everything 👀',
    )
    await bot.change_presence(activity=watching_status)
    logger.info(f'{bot.user} has connected to Discord!')
# File extensions (lower-case, with dot) treated as analyzable images.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp")
# Maximum characters per Discord message sent by the bot (below Discord's
# 2000-character limit, leaving room for MESSAGE_PREFIX).
_DISCORD_CHUNK_LENGTH = 1800


async def _reply_with_chunks(message, description_chunks):
    """Reply to *message* with the description, split to fit Discord's limit.

    The first non-empty piece is sent as a reply to *message*, prefixed with
    ``MESSAGE_PREFIX``. Every later piece replies to that first bot message so
    the description stays threaded. Empty chunks are skipped.
    """
    first_reply = None
    for chunk in description_chunks:
        for start in range(0, len(chunk), _DISCORD_CHUNK_LENGTH):
            piece = chunk[start:start + _DISCORD_CHUNK_LENGTH]
            logger.info("Sending message to Discord...")
            if first_reply is None:
                first_reply = await message.reply(f"{MESSAGE_PREFIX} {piece}")
            else:
                await first_reply.reply(piece)
            logger.info("Message sent successfully.")
            # Wait for a short delay before sending the next message to avoid rate-limiting
            await asyncio.sleep(1)


@bot.event
async def on_message(message):
    """Describe image attachments in guild messages.

    Ignored: the bot's own messages, direct messages, channels not in
    ``CHANNEL_IDS`` (when that is set), and messages starting with "quiet".
    Messages starting with "tags" go to the Gradio tagger. All others go to
    the vision model, with the message text used as the question.
    """
    # Ignore messages sent by the bot and in dms
    if message.author == bot.user or message.channel.type == discord.ChannelType.private:
        return
    try:
        # An unset CHANNEL_IDS means every channel is allowed.
        if CHANNEL_IDS and message.channel.id not in CHANNEL_IDS:
            return
        content_lower = message.content.lower()
        if content_lower.startswith("quiet"):
            return  # Do nothing if message starts with "quiet"
        image_attachments = [
            attachment
            for attachment in message.attachments
            if attachment.filename.lower().endswith(_IMAGE_EXTENSIONS)
        ]
        if not image_attachments:
            return
        async with message.channel.typing():
            for attachment in image_attachments:
                if content_lower.startswith("tags"):
                    description_chunks = await describe_image_with_gradio(attachment.url)
                else:
                    description_chunks = await describe_image_with_openai(attachment.url, message.content)
                await _reply_with_chunks(message, description_chunks)
    except Exception:
        # Top-level event boundary: log with traceback, keep the bot alive.
        logger.exception("Error handling message %s", message.id)
# Run the bot
async def main():
    """Log in with ``DISCORD_BOT_TOKEN`` and run the bot until it disconnects.

    Raises:
        RuntimeError: If ``DISCORD_BOT_TOKEN`` is not configured.
    """
    if not DISCORD_BOT_TOKEN:
        raise RuntimeError("DISCORD_BOT_TOKEN is not set; add it to the environment or the .env file.")
    # `async with` closes the client's HTTP session and gateway connection on
    # exit, including when start() is interrupted or raises.
    async with bot:
        await bot.start(DISCORD_BOT_TOKEN)


# Only start the bot when run as a script, not when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())