Skip to content

Commit

Permalink
Merge pull request #296 from oshoma/75-percent-notifier-rebased
Browse files Browse the repository at this point in the history
75 percent notifier rebased
  • Loading branch information
Eyobyb authored Feb 16, 2024
2 parents d7aa5b6 + e621320 commit eeb770b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 14 deletions.
5 changes: 4 additions & 1 deletion src/.env-sample
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ SLACK_VERIFICATION_TOKEN= # from `Basic Information` page of your Slack
# AWS_SECRET_KEY= # AWS Secret Key

# Github. Optional.
# GITHUB_AUTH_TOKEN= # Authorization token for Github API
# GITHUB_AUTH_TOKEN= # Authorization token for Github API

# Language model usage limits. Optional.
# DAILY_TOKEN_LIMIT # Daily limit on the number of tokens users can use.
6 changes: 4 additions & 2 deletions src/apps/slackapp/slackapp/bolt_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,11 @@ def event_test(client, say, event):
)
combined_id = user_id + "_" + team_id

slack_verbose_logger = SlackVerboseLogger(say, thread_ts)
if cfg.FLASK_DEBUG:
can_excute = True
else:
user_db = UserUsageTracker(max_daily_token=cfg.DAILY_TOKEN_LIMIT)
user_db = UserUsageTracker(verbose_logger=slack_verbose_logger)

usage_cheker = user_db.check_usage(
user_id=user_id,
Expand Down Expand Up @@ -245,13 +246,14 @@ def event_test(client, say, event):
openai_api_key=cfg.OPENAI_API_KEY,
user_id=user_id,
team_id=team_id,
verbose_logger = slack_verbose_logger,
temperature=cfg.TEMPRATURE,
)

results = get_response(
question,
previous_messages,
verbose_logger=SlackVerboseLogger(say, thread_ts),
verbose_logger=slack_verbose_logger,
bot_info=bot,
llm=llm,
)
Expand Down
2 changes: 1 addition & 1 deletion src/sherpa_ai/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
LOG_LEVEL = environ.get("LOG_LEVEL", "INFO").upper()

# Usage setting
DAILY_TOKEN_LIMIT = environ.get("DAILY_TOKEN_LIMIT") or 20000
DAILY_TOKEN_LIMIT = float(environ.get("DAILY_TOKEN_LIMIT") or 20000)
DAILY_LIMIT_REACHED_MESSAGE = (
environ.get("DAILY_LIMIT_REACHED_MESSAGE")
or "Sorry for the inconvenience, but it seems that you have exceeded your daily token limit. As a result, you will need to try again after 24 hours. Thank you for your understanding."
Expand Down
76 changes: 68 additions & 8 deletions src/sherpa_ai/database/user_usage_tracker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import time

from anyio import Path
import boto3
from sqlalchemy import TIMESTAMP, Boolean, Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

import sherpa_ai.config as cfg
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger
from sherpa_ai.verbose_loggers.verbose_loggers import DummyVerboseLogger

Base = declarative_base()

Expand All @@ -18,6 +21,7 @@ class UsageTracker(Base):
token = Column(Integer)
timestamp = Column(Integer)
reset_timestamp = Column(Boolean)
reminded_timestamp = Column(Boolean)


class Whitelist(Base):
Expand All @@ -31,17 +35,22 @@ class UserUsageTracker:
def __init__(
self,
db_name=cfg.DB_NAME,
max_daily_token=20000,
verbose_logger: BaseVerboseLogger = DummyVerboseLogger(),
):
self.engine = create_engine(db_name)
Session = sessionmaker(bind=self.engine)
self.session = Session()
self.create_table()
self.max_daily_token = int(max_daily_token)
self.max_daily_token = cfg.DAILY_TOKEN_LIMIT
self.verbose_logger = verbose_logger
self.is_reminded = False
self.usage_percentage_allowed = 75

def download_from_s3(self, bucket_name, s3_file_key, local_file_path):
s3 = boto3.client("s3")
s3.download_file(bucket_name, s3_file_key, local_file_path)
file_path = Path("./token_counter.db")
if not file_path.exists():
s3 = boto3.client("s3")
s3.download_file(bucket_name, s3_file_key, local_file_path)

def upload_to_s3(self, local_file_path, bucket_name, s3_file_key):
s3 = boto3.client("s3")
Expand Down Expand Up @@ -70,16 +79,55 @@ def get_whitelist_by_user_id(self, user_id):
def is_in_whitelist(self, user_id):
return bool(self.get_whitelist_by_user_id(user_id))

def add_data(self, combined_id, token, reset_timestamp=False):
def add_and_check_data(
self, combined_id, token, reset_timestamp=False, reminded_timestamp=False
):
self.add_data(
combined_id=combined_id,
token=token,
reset_timestamp=reset_timestamp,
reminded_timestamp=reminded_timestamp,
)
self.remind_user_of_daily_token_limit(combined_id=combined_id)

def add_data(
self, combined_id, token, reset_timestamp=False, reminded_timestamp=False
):
data = UsageTracker(
user_id=combined_id,
token=token,
timestamp=int(time.time()),
reset_timestamp=reset_timestamp,
reminded_timestamp=reminded_timestamp,
)
self.session.add(data)
self.session.commit()

def percentage_used(self, combined_id):
total_token_since_last_reset = self.get_sum_of_tokens_since_last_reset(
user_id=combined_id
)
return (total_token_since_last_reset * 100) / self.max_daily_token

def remind_user_of_daily_token_limit(self, combined_id):
split_parts = combined_id.split("_")
user_id = ""
if len(split_parts) > 0:
user_id = split_parts[0]

user_is_whitelisted = self.is_in_whitelist(user_id)
self.is_reminded = self.check_if_reminded(combined_id=combined_id)
if not user_is_whitelisted and not self.is_reminded:
if (
self.percentage_used(combined_id=combined_id) > self.usage_percentage_allowed
and not self.is_reminded
):
self.add_data(combined_id=combined_id, token=0, reminded_timestamp=True)

self.verbose_logger.log(
f"Hi friend, you have used up {self.usage_percentage_allowed}% of your daily token limit. once you go over the limit there will be a 24 hour cool down period after which you can continue using Sherpa! be awesome!"
)

def get_data_since_last_reset(self, user_id):
last_reset_info = self.get_last_reset_info(user_id)

Expand All @@ -92,6 +140,7 @@ def get_data_since_last_reset(self, user_id):
"token": item.token,
"timestamp": item.timestamp,
"reset_timestamp": item.reset_timestamp,
"reminded_timestamp": item.reminded_timestamp,
}
for item in data
]
Expand All @@ -111,10 +160,19 @@ def get_data_since_last_reset(self, user_id):
"token": item.token,
"timestamp": item.timestamp,
"reset_timestamp": item.reset_timestamp,
"reminded_timestamp": item.reminded_timestamp,
}
for item in data
]

def check_if_reminded(self, combined_id):
data_list = self.get_data_since_last_reset(combined_id)
is_reminded_true = any(
item.get("reminded_timestamp", False) for item in data_list
)

return is_reminded_true

def get_sum_of_tokens_since_last_reset(self, user_id):
data_since_last_reset = self.get_data_since_last_reset(user_id)

Expand All @@ -125,7 +183,9 @@ def get_sum_of_tokens_since_last_reset(self, user_id):
return token_sum

def reset_usage(self, combined_id, token_amount):
self.add_data(combined_id=combined_id, token=token_amount, reset_timestamp=True)
self.add_and_check_data(
combined_id=combined_id, token=token_amount, reset_timestamp=True
)

def get_last_reset_info(self, combined_id):
data = (
Expand Down Expand Up @@ -192,7 +252,7 @@ def check_usage(self, user_id, combined_id, token_amount):
"time_left": self.seconds_to_hms(time_since_last_reset),
}
else:
self.add_data(combined_id=combined_id, token=token_amount)
self.add_and_check_data(combined_id=combined_id, token=token_amount)
return {
"token-left": self.max_daily_token
- total_token_since_last_reset,
Expand Down
7 changes: 5 additions & 2 deletions src/sherpa_ai/models/sherpa_base_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from langchain.schema import BaseMessage, ChatResult

from sherpa_ai.database.user_usage_tracker import UserUsageTracker
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger


class SherpaBaseChatModel(BaseChatModel):
team_id: typing.Optional[str] = None
user_id: typing.Optional[str] = None
verbose_logger: BaseVerboseLogger = None

def _agenerate(
self,
Expand Down Expand Up @@ -44,7 +46,7 @@ def _generate(
total_token = token_before + token_after
if self.team_id and self.user_id:
combined_id = self.user_id + "_" + self.team_id
user_db = UserUsageTracker()
user_db = UserUsageTracker(verbose_logger=self.verbose_logger)
user_db.add_data(combined_id=combined_id, token=total_token)
user_db.close_connection()

Expand All @@ -54,6 +56,7 @@ def _generate(
class SherpaChatOpenAI(ChatOpenAI):
team_id: typing.Optional[str] = None
user_id: typing.Optional[str] = None
verbose_logger: BaseVerboseLogger = None

def _agenerate(
self,
Expand Down Expand Up @@ -83,7 +86,7 @@ def _generate(
total_token = token_before + token_after
if self.team_id and self.user_id:
combined_id = self.user_id + "_" + self.team_id
user_db = UserUsageTracker()
user_db = UserUsageTracker(verbose_logger=self.verbose_logger)
user_db.add_data(combined_id=combined_id, token=total_token)
user_db.close_connection()

Expand Down

0 comments on commit eeb770b

Please sign in to comment.