Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix&Add&Enhance: Fix token calculation; Optimize Styling; Add Animation #68

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
29 changes: 8 additions & 21 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
black==23.1.0
certifi==2022.12.7
charset-normalizer==3.0.1
click==8.1.3
idna==3.4
markdown-it-py==2.2.0
mdurl==0.1.2
mypy-extensions==1.0.0
packaging==23.0
pathspec==0.11.0
platformdirs==3.1.0
prompt-toolkit==3.0.38
Pygments==2.14.0
PyYAML>=6.0,<6.1
requests==2.28.2
rich==13.4.0
tomli==2.0.1
typing_extensions==4.6.2
urllib3==1.26.14
wcwidth==0.2.6
xdg-base-dirs~=6.0.0
colorama==0.4.6
prompt_toolkit==3.0.42
pyperclip==1.8.2
PySocks==1.7.1
PyYAML==6.0.1
Requests==2.32.3
rich==13.7.1
tabulate==0.9.0
xdg_base_dirs==6.0.1
Binary file modified screenshot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
138 changes: 98 additions & 40 deletions src/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
import requests
import sys
import yaml
import threading
import time
import itertools

from pathlib import Path
from prompt_toolkit import PromptSession, HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.application.current import get_app
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.logging import RichHandler
from rich.markdown import Markdown
from typing import Optional
from xdg_base_dirs import xdg_config_home
from tabulate import tabulate
from colorama import init, Fore


BASE = Path(xdg_config_home(), "chatgpt-cli")
Expand Down Expand Up @@ -60,7 +67,8 @@
level="INFO",
format="%(message)s",
handlers=[
RichHandler(show_time=False, show_level=False, show_path=False, markup=True)
RichHandler(show_time=False, show_level=False,
show_path=False, markup=True)
],
)

Expand Down Expand Up @@ -93,6 +101,31 @@
"proxy": "socks5://127.0.0.1:2080",
}

# Initialize colorama
init(autoreset=True)

# Function to show spinner in a separate thread


def show_spinner(stop_event, length_char=7, spinner_clear_length=9):
    """
    Render a growing ">" progress spinner on stdout until *stop_event* is set.

    Intended to run in a worker thread while the main thread blocks on a
    network request; the caller sets the event and joins the thread.

    :param stop_event: threading.Event the caller sets to stop the spinner
    :param length_char: maximum number of ">" characters in one sweep
    :param spinner_clear_length: how many columns to blank out when done
    """
    # Same escape sequence colorama's Fore.GREEN expands to.  Using it
    # directly avoids re-importing/re-initialising colorama on every call:
    # the module already runs init(autoreset=True) once at import time.
    green = "\x1b[32m"
    spinner_char = ">"

    while not stop_event.is_set():
        for i in range(length_char + 1):
            # Also check inside the sweep, so the spinner disappears within
            # ~0.1 s of the stop signal instead of finishing the whole sweep.
            if stop_event.is_set():
                break
            frame = f"{green}{spinner_char * i:<{length_char}}"
            sys.stdout.write(f"\r{frame}")
            sys.stdout.flush()
            # wait() doubles as the frame delay and wakes early when the
            # event is set, unlike the previous unconditional time.sleep().
            stop_event.wait(0.1)

    # Clear the spinner line once done
    sys.stdout.write("\r" + " " * spinner_clear_length + "\r")
    sys.stdout.flush()


def load_config(config_file: str) -> dict:
"""
Expand Down Expand Up @@ -135,7 +168,8 @@ def get_last_save_file() -> str:
"""
files = [f for f in os.listdir(SAVE_FOLDER) if f.endswith(".json")]
if files:
ts = [f.replace("chatgpt-session-", "").replace(".json", "") for f in files]
ts = [f.replace("chatgpt-session-", "").replace(".json", "")
for f in files]
ts.sort()
return ts[-1]
return None
Expand Down Expand Up @@ -173,7 +207,8 @@ def add_markdown_system_message() -> None:
"""
Try to force ChatGPT to always respond with well formatted code blocks and tables if markdown is enabled.
"""
instruction = "Always use code blocks with the appropriate language tags. If asked for a table always format it using Markdown syntax."
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
instruction = f"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture. Current date: {current_date}. Always use code blocks with the appropriate language tags. If asked for a table always format it using Markdown syntax."
messages.append({"role": "system", "content": instruction})


Expand All @@ -182,45 +217,50 @@ def calculate_expense(
completion_tokens: int,
prompt_pricing: float,
completion_pricing: float,
) -> float:
"""
Calculate the expense, given the number of tokens and the pricing rates
) -> tuple:
"""
Calculate the total expense, prompt expense, and completion expense,
given the number of tokens and the pricing rates
"""
expense = ((prompt_tokens / 1000) * prompt_pricing) + (
(completion_tokens / 1000) * completion_pricing
)
prompt_expense = (prompt_tokens / 1000) * prompt_pricing
completion_expense = (completion_tokens / 1000) * completion_pricing
total_expense = prompt_expense + completion_expense

# Format to display in decimal notation rounded to 6 decimals
expense = "{:.6f}".format(round(expense, 6))
prompt_expense = "{:.6f}".format(round(prompt_expense, 6))
completion_expense = "{:.6f}".format(round(completion_expense, 6))
total_expense = "{:.6f}".format(round(total_expense, 6))

return expense
return total_expense, prompt_expense, completion_expense


def display_expense(model: str) -> None:
"""
Given the model used, display total tokens used and estimated expense
"""
logger.info(
f"\nTotal tokens used: [green bold]{prompt_tokens + completion_tokens}",
extra={"highlighter": None},
)

console.rule(f"[bold green]Usage and Expense")
if model in PRICING_RATE:
total_expense = calculate_expense(
total_expense, prompt_expense, completion_expense = calculate_expense(
prompt_tokens,
completion_tokens,
PRICING_RATE[model]["prompt"],
PRICING_RATE[model]["completion"],
)
logger.info(
f"Estimated expense: [green bold]${total_expense}",
extra={"highlighter": None},
)
else:
logger.warning(
f"[red bold]No expense estimate available for model {model}",
extra={"highlighter": None},
)
table_data = [
["Token Type", "Count", "Price"],
["Prompt Tokens", prompt_tokens, prompt_expense],
["Total Completion Tokens", completion_tokens, completion_expense],
["Total Tokens", prompt_tokens + completion_tokens, total_expense]
]
table_str = tabulate(table_data, headers="firstrow", tablefmt="grid")

logger.info(table_str)
console.rule(f"[bold green]ChatGPT CLI End")


def print_markdown(content: str, code_blocks: Optional[dict] = None):
Expand All @@ -233,7 +273,8 @@ def print_markdown(content: str, code_blocks: Optional[dict] = None):
return

lines = content.split("\n")
code_block_id = 0 if code_blocks is None else 1 + max(code_blocks.keys(), default=0)
code_block_id = 0 if code_blocks is None else 1 + \
max(code_blocks.keys(), default=0)
code_block_open = False
code_block_language = ""
code_block_content = []
Expand All @@ -252,7 +293,8 @@ def print_markdown(content: str, code_blocks: Optional[dict] = None):
if code_blocks is not None:
code_blocks[code_block_id] = snippet_text
formatted_code_block = f"```{code_block_language}\n{snippet_text}\n```"
console.print(f"Block {code_block_id}", style="blue", justify="right")
console.print(f"Block {code_block_id}",
style="blue", justify="right")
console.print(Markdown(formatted_code_block))
code_block_id += 1
code_block_content = []
Expand Down Expand Up @@ -282,14 +324,12 @@ def start_prompt(
global prompt_tokens, completion_tokens

message = ""

current_time_Human = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
console.rule(f"[bold blue]Human | {current_time_Human}", style="blue")
if config["non_interactive"]:
message = sys.stdin.read()
else:
message = session.prompt(
HTML(f"<b>[{prompt_tokens + completion_tokens}] >>> </b>")
)

message = session.prompt(HTML('<b><ansiblue>>>> </ansiblue></b>'))
if message.lower().strip() == "/q":
raise EOFError
if message.lower() == "":
Expand Down Expand Up @@ -343,7 +383,10 @@ def start_prompt(
body["max_tokens"] = config["max_tokens"]
if config["json_mode"]:
body["response_format"] = {"type": "json_object"}

stop_spinner = threading.Event()
spinner_thread = threading.Thread(
target=show_spinner, args=(stop_spinner,))
spinner_thread.start()
try:
if config["supplier"] == "azure":
headers = {
Expand All @@ -368,37 +411,49 @@ def start_prompt(
proxies=proxy,
)
except requests.ConnectionError:
stop_spinner.set()
logger.error(
"[red bold]Connection error, try again...", extra={"highlighter": None}
)
messages.pop()
raise KeyboardInterrupt
except requests.Timeout:
stop_spinner.set()
logger.error(
"[red bold]Connection timed out, try again...", extra={"highlighter": None}
)
messages.pop()
raise KeyboardInterrupt

finally:
stop_spinner.set()
spinner_thread.join()
match r.status_code:
case 200:
response = r.json()

message_response = response["choices"][0]["message"]
usage_response = response["usage"]

current_time_Agent = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
current_completions_tokens = usage_response["completion_tokens"]
current_prompt_tokens = usage_response["prompt_tokens"]
console.rule(
f"[bold blue][{current_prompt_tokens}]", style="blue")
console.rule(
f"[bold red]Agent | {current_time_Agent}", style="bold red")
if not config["non_interactive"]:
console.line()
if config["markdown"]:
print_markdown(message_response["content"].strip(), copyable_blocks)
print_markdown(
message_response["content"].strip(), copyable_blocks)
else:
print(message_response["content"].strip())
if not config["non_interactive"]:
console.line()

# Update message history and token counters
console.rule(
f"[bold red][{current_completions_tokens}]", style="red")
messages.append(message_response)
prompt_tokens += usage_response["prompt_tokens"]
prompt_tokens = usage_response["prompt_tokens"]
completion_tokens += usage_response["completion_tokens"]
save_history(model, messages, prompt_tokens, completion_tokens)

Expand All @@ -416,11 +471,13 @@ def start_prompt(
)
raise EOFError
# TODO: Develop a better strategy to manage this case
logger.error("[red bold]Invalid request", extra={"highlighter": None})
logger.error("[red bold]Invalid request",
extra={"highlighter": None})
raise EOFError

case 401:
logger.error("[red bold]Invalid API Key", extra={"highlighter": None})
logger.error("[red bold]Invalid API Key",
extra={"highlighter": None})
raise EOFError

case 429:
Expand Down Expand Up @@ -492,8 +549,7 @@ def main(
# If non interactive suppress the logging messages
if non_interactive:
logger.setLevel("ERROR")

logger.info("[bold]ChatGPT CLI", extra={"highlighter": None})
console.rule(f"[bold green]ChatGPT CLI Start")

history = FileHistory(HISTORY_FILE)

Expand Down Expand Up @@ -559,7 +615,8 @@ def main(
logger.info(
f"Supplier: [green bold]{config['supplier']}", extra={"highlighter": None}
)
logger.info(f"Model in use: [green bold]{model}", extra={"highlighter": None})
logger.info(f"Model in use: [green bold]{model}", extra={
"highlighter": None})

# Add the system message for code blocks in case markdown is enabled in the config file
if config["markdown"]:
Expand All @@ -584,7 +641,8 @@ def main(
global prompt_tokens, completion_tokens
# If this feature is used --context is cleared
messages.clear()
history_data = load_history_data(os.path.join(SAVE_FOLDER, restore_file))
history_data = load_history_data(
os.path.join(SAVE_FOLDER, restore_file))
for message in history_data["messages"]:
messages.append(message)
prompt_tokens += history_data["prompt_tokens"]
Expand Down