From 74b5dc4db20b8e3274e8d0bf8392ebc78dfe5117 Mon Sep 17 00:00:00 2001
From: Emery Berger <emery.berger@gmail.com>
Date: Sun, 4 Feb 2024 17:07:19 -0500
Subject: [PATCH 1/5] Switched to logging here. Needs parameterization for the
 logger.

---
 src/llm_utils/chat.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llm_utils/chat.py b/src/llm_utils/chat.py
index 886e7dd..13e3fdc 100644
--- a/src/llm_utils/chat.py
+++ b/src/llm_utils/chat.py
@@ -25,7 +25,6 @@
 
 log = logging.getLogger("rich")
 
-
 class ChatAPI(abc.ABC, Generic[T]):
     prompt_tokens: int
     completion_tokens: int
@@ -166,7 +165,7 @@ async def send_message(
         payload = cls.create_payload(conversation)
         for _ in range(5):
             inference = cls.get_inference(payload)
-            print("INFERENCE", inference["completion"])
+            log.info(f'Result: {inference["completion"]}')
 
             jsonified_completion = contains_valid_json(inference["completion"])
             if jsonified_completion is not None:

From 55934df5f508216edcf7a7746def1ebe9455c7bf Mon Sep 17 00:00:00 2001
From: Emery Berger <emery.berger@gmail.com>
Date: Sun, 4 Feb 2024 17:20:40 -0500
Subject: [PATCH 2/5] Some clarification and use of a variable, log to file
 for now.

---
 src/llm_utils/chat.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llm_utils/chat.py b/src/llm_utils/chat.py
index 13e3fdc..2c5d16d 100644
--- a/src/llm_utils/chat.py
+++ b/src/llm_utils/chat.py
@@ -24,6 +24,7 @@
 from llm_utils.utils import contains_valid_json, extract_code_blocks
 
 log = logging.getLogger("rich")
+logging.basicConfig(filename='llm_utils.log', encoding='utf-8', level=logging.DEBUG)
 
 class ChatAPI(abc.ABC, Generic[T]):
     prompt_tokens: int
@@ -116,8 +117,8 @@ class Claude(ChatAPI[ClaudeMessageParam]):
     MODEL_ID: str = "anthropic.claude-v2"
     SERVICE_NAME: str = "bedrock"
     MAX_RETRY: int = 5
-    prompt_tokens: int = 0
-    completion_tokens: int = 0
+    prompt_tokens: int = 0  # FIXME not yet implemented
+    completion_tokens: int = 0  # ibid
 
     @classmethod
     def generate_chatlog(cls, conversations: List[ClaudeMessageParam]) -> str:
@@ -163,7 +164,7 @@ async def send_message(
             first_msg.rstrip() + "\n" + conversation[0]["content"]
         )
         payload = cls.create_payload(conversation)
-        for _ in range(5):
+        for _ in range(cls.MAX_RETRY):
             inference = cls.get_inference(payload)
             log.info(f'Result: {inference["completion"]}')
 

From ca460dcc157d8d963136cb91cc62e7ce2e346969 Mon Sep 17 00:00:00 2001
From: Emery Berger <emery.berger@gmail.com>
Date: Mon, 5 Feb 2024 12:44:28 -0500
Subject: [PATCH 3/5] Newer tiktoken.

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index f4c1334..e550853 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
   { name="Emery Berger", email="emery.berger@gmail.com" },
   { name="Sam Stern", email="sternj@umass.edu" }
 ]
-dependencies = ["tiktoken>=0.5.1", "openai>=1.11.0", "botocore>=1.34.34", "botocore-types>=0.2.2", "types-requests>=2.31.0.20240125"]
+dependencies = ["tiktoken>=0.5.2", "openai>=1.11.0", "botocore>=1.34.34", "botocore-types>=0.2.2", "types-requests>=2.31.0.20240125"]
 description = "Utilities for our LLM projects (CWhy, ChatDBG, ...)."
 readme = "README.md"
 requires-python = ">=3.9"

From 4410ff3c6a682c916d1243ef0dd806c62962afd4 Mon Sep 17 00:00:00 2001
From: Emery Berger <emery.berger@gmail.com>
Date: Mon, 5 Feb 2024 12:44:52 -0500
Subject: [PATCH 4/5] Allow slashes in model names. Degrade gracefully if
 tokenizer is not present.

---
 src/llm_utils/utils.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/llm_utils/utils.py b/src/llm_utils/utils.py
index 8d5e77d..3fa5ebc 100644
--- a/src/llm_utils/utils.py
+++ b/src/llm_utils/utils.py
@@ -7,9 +7,17 @@
 # OpenAI specific.
 def count_tokens(model: str, string: str) -> int:
     """Returns the number of tokens in a text string."""
-    encoding = tiktoken.encoding_for_model(model)
-    num_tokens = len(encoding.encode(string))
-    return num_tokens
+    def extract_after_slash(s):
+        # Split the string by '/' and return the part after it if '/' is found, else return the whole string
+        parts = s.split('/', 1)  # The '1' ensures we split at the first '/' only
+        return parts[1] if len(parts) > 1 else s
+
+    try:
+        encoding = tiktoken.encoding_for_model(extract_after_slash(model))
+        num_tokens = len(encoding.encode(string))
+        return num_tokens
+    except KeyError:
+        return 0
 
 
 # OpenAI specific.

From b320d18e7dbc1445b87131393a9b20f2f08844d8 Mon Sep 17 00:00:00 2001
From: Emery Berger <emery.berger@gmail.com>
Date: Mon, 5 Feb 2024 13:05:28 -0500
Subject: [PATCH 5/5] Added type annotations.

---
 src/llm_utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llm_utils/utils.py b/src/llm_utils/utils.py
index 3fa5ebc..e90cdca 100644
--- a/src/llm_utils/utils.py
+++ b/src/llm_utils/utils.py
@@ -7,7 +7,7 @@
 # OpenAI specific.
 def count_tokens(model: str, string: str) -> int:
     """Returns the number of tokens in a text string."""
-    def extract_after_slash(s):
+    def extract_after_slash(s: str) -> str:
         # Split the string by '/' and return the part after it if '/' is found, else return the whole string
         parts = s.split('/', 1)  # The '1' ensures we split at the first '/' only
         return parts[1] if len(parts) > 1 else s
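
A quick sanity check of the behavior introduced by patches 4 and 5. This is
a minimal sketch, not part of the series; it assumes the llm_utils package
is installed, and the model names below are only illustrative:

    # Illustrative only: exercises the patched count_tokens from
    # src/llm_utils/utils.py.
    from llm_utils.utils import count_tokens

    # A model tiktoken knows: tokens are counted normally.
    print(count_tokens("gpt-3.5-turbo", "hello world"))         # e.g. 2

    # A slash-prefixed name: the part before the first '/' is stripped,
    # so this resolves to the same tokenizer as above.
    print(count_tokens("openai/gpt-3.5-turbo", "hello world"))  # e.g. 2

    # A model tiktoken does not know: encoding_for_model raises KeyError,
    # which the patched function now catches, returning 0 instead of crashing.
    print(count_tokens("anthropic.claude-v2", "hello world"))   # 0

One caveat of the graceful-degradation path: a result of 0 is
indistinguishable from an empty string, so callers that budget context
windows by token count should treat 0 for an unrecognized model as
"unknown" rather than "free".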