Skip to content

Commit

Permalink
Check early if file tokens exceed or near limit
Browse files Browse the repository at this point in the history
  • Loading branch information
biobootloader authored and PCSwingle committed Aug 4, 2023
1 parent 77c2bd0 commit c6db6ee
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
10 changes: 4 additions & 6 deletions mentat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .conversation import Conversation
from .errors import MentatError, UserError
from .git_handler import get_shared_git_root_for_paths
from .llm_api import CostTracker, count_tokens, setup_api_key
from .llm_api import CostTracker, setup_api_key
from .logging_config import setup_logging
from .user_input_manager import UserInputManager, UserQuitInterrupt

Expand Down Expand Up @@ -77,7 +77,7 @@ def run(paths: Iterable[str], exclude_paths: Optional[Iterable[str]] = None):
MentatError,
) as e:
if str(e):
cprint(e, "light_yellow")
cprint("\n" + str(e), "red")
finally:
cost_tracker.display_total_cost()

Expand All @@ -89,7 +89,6 @@ def loop(
) -> None:
git_root = get_shared_git_root_for_paths(paths)
config = ConfigManager(git_root)
conv = Conversation(config, cost_tracker)
user_input_manager = UserInputManager(config)
code_file_manager = CodeFileManager(
paths,
Expand All @@ -98,17 +97,16 @@ def loop(
config,
git_root,
)
conv = Conversation(config, cost_tracker, code_file_manager)

tokens = count_tokens(code_file_manager.get_code_message())
cprint(f"\nFile token count: {tokens}", "cyan")
cprint("Type 'q' or use Ctrl-C to quit at any time.\n", color="cyan")
cprint("What can I do for you?", color="light_blue")
need_user_request = True
while True:
if need_user_request:
user_response = user_input_manager.collect_user_input()
conv.add_user_message(user_response)
explanation, code_changes = conv.get_model_response(code_file_manager, config)
explanation, code_changes = conv.get_model_response(config)

if code_changes:
need_user_request = get_user_feedback_on_changes(
Expand Down
1 change: 1 addition & 0 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
get_paths_with_git_diffs(self.git_root),
self.git_root,
)
print()

def _set_file_paths(
self, paths: Iterable[str], exclude_paths: Iterable[str]
Expand Down
34 changes: 28 additions & 6 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from termcolor import cprint

from .code_change import CodeChange
from .code_file_manager import CodeFileManager
from .config_manager import ConfigManager
Expand All @@ -7,12 +9,34 @@


class Conversation:
def __init__(self, config: ConfigManager, cost_tracker: CostTracker):
def __init__(
    self,
    config: ConfigManager,
    cost_tracker: CostTracker,
    code_file_manager: CodeFileManager,
):
    """Start a conversation seeded with the system prompt.

    Eagerly counts the tokens of the included files and compares them to
    the model's context limit, so the user learns immediately — before
    typing anything — whether the selected files even fit.

    Raises:
        KeyboardInterrupt: when the files alone already exceed the limit.
            NOTE(review): raising KeyboardInterrupt for a size check is
            unusual — presumably the caller's quit handling catches it and
            shows the message; confirm against app.run().
    """
    self.messages = []
    self.add_system_message(system_prompt)
    self.cost_tracker = cost_tracker
    self.code_file_manager = code_file_manager
    self.allow_32k = check_model_availability(config.allow_32k())

    # 32k context when available, otherwise the base 8k window.
    token_limit = 32768 if self.allow_32k else 8192
    tokens = count_tokens(code_file_manager.get_code_message())

    # Fail fast: the file context alone cannot fit in the model window.
    if tokens > token_limit:
        raise KeyboardInterrupt(
            f"Included files already exceed token limit ({tokens} / {token_limit})."
            " Please try running again with a reduced number of files."
        )

    # Warn when fewer than ~1000 tokens of headroom remain for the chat.
    if tokens + 1000 > token_limit:
        cprint(
            f"Warning: Included files are close to token limit ({tokens} /"
            f" {token_limit}), you may not be able to have a long conversation.",
            "red",
        )
    else:
        cprint(f"File token count: {tokens} / {token_limit}", "cyan")

def add_system_message(self, message: str):
self.messages.append({"role": "system", "content": message})

Expand All @@ -22,18 +46,16 @@ def add_user_message(self, message: str):
def add_assistant_message(self, message: str):
self.messages.append({"role": "assistant", "content": message})

def get_model_response(
self, code_file_manager: CodeFileManager, config: ConfigManager
) -> (str, list[CodeChange]):
def get_model_response(self, config: ConfigManager) -> (str, list[CodeChange]):
messages = self.messages.copy()

code_message = code_file_manager.get_code_message()
code_message = self.code_file_manager.get_code_message()
messages.append({"role": "system", "content": code_message})

model, num_prompt_tokens = choose_model(messages, self.allow_32k)

state = run_async_stream_and_parse_llm_response(
messages, model, code_file_manager
messages, model, self.code_file_manager
)

self.cost_tracker.display_api_call_stats(
Expand Down

0 comments on commit c6db6ee

Please sign in to comment.