Merge pull request #5 from plasma-umass/adding_functionality
Adds logging by default, other mods to allow use with litellm
emeryberger authored Feb 5, 2024
2 parents 067b62b + b320d18 commit d142235
Showing 3 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ authors = [
{ name="Emery Berger", email="[email protected]" },
{ name="Sam Stern", email="[email protected]" }
]
dependencies = ["tiktoken>=0.5.1", "openai>=1.11.0", "botocore>=1.34.34", "botocore-types>=0.2.2", "types-requests>=2.31.0.20240125"]
dependencies = ["tiktoken>=0.5.2", "openai>=1.11.0", "botocore>=1.34.34", "botocore-types>=0.2.2", "types-requests>=2.31.0.20240125"]
description = "Utilities for our LLM projects (CWhy, ChatDBG, ...)."
readme = "README.md"
requires-python = ">=3.9"
10 changes: 5 additions & 5 deletions src/llm_utils/chat.py
@@ -24,7 +24,7 @@
from llm_utils.utils import contains_valid_json, extract_code_blocks

log = logging.getLogger("rich")

+logging.basicConfig(filename='llm_utils.log', encoding='utf-8', level=logging.DEBUG)

class ChatAPI(abc.ABC, Generic[T]):
prompt_tokens: int
@@ -117,8 +117,8 @@ class Claude(ChatAPI[ClaudeMessageParam]):
MODEL_ID: str = "anthropic.claude-v2"
SERVICE_NAME: str = "bedrock"
MAX_RETRY: int = 5
-prompt_tokens: int = 0
-completion_tokens: int = 0
+prompt_tokens: int = 0 # FIXME not yet implemented
+completion_tokens: int = 0 # ibid

@classmethod
def generate_chatlog(cls, conversations: List[ClaudeMessageParam]) -> str:
@@ -164,9 +164,9 @@ async def send_message(
first_msg.rstrip() + "\n" + conversation[0]["content"]
)
payload = cls.create_payload(conversation)
-for _ in range(5):
+for _ in range(cls.MAX_RETRY):
inference = cls.get_inference(payload)
print("INFERENCE", inference["completion"])
log.info(f'Result: {inference["completion"]}')

jsonified_completion = contains_valid_json(inference["completion"])
if jsonified_completion is not None:
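
As a rough illustration of the pattern these chat.py changes establish (file logging by default, retries bounded by MAX_RETRY, completions logged via log.info rather than printed), here is a minimal self-contained sketch. The get_inference and contains_valid_json bodies below are simplified stand-ins for illustration only, not the library's real implementations.

import json
import logging
from typing import Optional

# Mirrors the new default: log records go to llm_utils.log instead of stdout.
logging.basicConfig(filename="llm_utils.log", encoding="utf-8", level=logging.DEBUG)
log = logging.getLogger("rich")

MAX_RETRY = 5  # the class constant now used in place of a hard-coded 5


def contains_valid_json(text: str) -> Optional[dict]:
    # Stand-in for llm_utils.utils.contains_valid_json: parsed JSON dict or None.
    try:
        result = json.loads(text)
        return result if isinstance(result, dict) else None
    except json.JSONDecodeError:
        return None


def get_inference(payload: dict) -> dict:
    # Stand-in for the model call; a real implementation would invoke the service.
    return {"completion": '{"answer": 42}'}


def send_message(payload: dict) -> Optional[dict]:
    # Retry up to MAX_RETRY times, logging each completion instead of printing it.
    for _ in range(MAX_RETRY):
        inference = get_inference(payload)
        log.info(f'Result: {inference["completion"]}')
        jsonified_completion = contains_valid_json(inference["completion"])
        if jsonified_completion is not None:
            return jsonified_completion
    return None


if __name__ == "__main__":
    print(send_message({"prompt": "example"}))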
14 changes: 11 additions & 3 deletions src/llm_utils/utils.py
@@ -7,9 +7,17 @@
# OpenAI specific.
def count_tokens(model: str, string: str) -> int:
    """Returns the number of tokens in a text string."""
-    encoding = tiktoken.encoding_for_model(model)
-    num_tokens = len(encoding.encode(string))
-    return num_tokens
+    def extract_after_slash(s: str) -> str:
+        # Split the string by '/' and return the part after it if '/' is found, else return the whole string
+        parts = s.split('/', 1)  # The '1' ensures we split at the first '/' only
+        return parts[1] if len(parts) > 1 else s
+
+    try:
+        encoding = tiktoken.encoding_for_model(extract_after_slash(model))
+        num_tokens = len(encoding.encode(string))
+        return num_tokens
+    except KeyError:
+        return 0


# OpenAI specific.
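
A quick usage sketch of the revised count_tokens, assuming the package is installed; the model strings below are illustrative examples, not values from the commit. The inner extract_after_slash helper lets litellm-style "provider/model" names resolve to the right tiktoken encoding, and unrecognized models now return 0 instead of raising KeyError.

# Usage sketch; model names below are illustrative only.
from llm_utils.utils import count_tokens

print(count_tokens("gpt-4", "hello world"))          # plain OpenAI model name
print(count_tokens("openai/gpt-4", "hello world"))   # litellm-style "provider/model" name
print(count_tokens("not-a-real-model", "hello"))     # unknown model -> 0 (KeyError is caught)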
