more resilient anthropic (model settings, mdown prompt structure, json correction) #34

Merged (4 commits) on May 26, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ bedrock = [
"botocore>=1.34.11",
"boto3>=1.34.11",
"awscli>=1.32.11",
"fix_busted_json>0.0.18",
]
chere = [
"cohere>=4.46",
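The new fix_busted_json dependency supplies the JSON-repair helpers used later in this PR (repair_json and largest_json). A minimal sketch of the intended use, assuming a hypothetical model reply that wraps slightly malformed JSON in prose:

# Illustrative only; the raw string below is a made-up model reply.
import json
from fix_busted_json import largest_json, repair_json

raw = 'Sure, here is your answer: {"name": "Alice", "age": 30,} Hope that helps!'
candidate = largest_json(raw)   # isolate the largest JSON-like substring
clean = repair_json(candidate)  # fix trailing commas, quoting, etc.
print(json.loads(clean))        # {'name': 'Alice', 'age': 30}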
52 changes: 31 additions & 21 deletions src/pydantic_prompter/llm_providers/__init__.py
@@ -1,30 +1,40 @@
from typing import Dict, Optional, Type
from pydantic_prompter.annotation_parser import AnnotationParser
from pydantic_prompter.common import logger
from pydantic_prompter.llm_providers.bedrock_anthropic import BedRockAnthropic
from pydantic_prompter.llm_providers.bedrock_cohere import BedRockCohere
from pydantic_prompter.llm_providers.bedrock_llama2 import BedRockLlama2
from pydantic_prompter.llm_providers.cohere import Cohere
from pydantic_prompter.llm_providers.openai import OpenAI
from pydantic_prompter.llm_providers.base import LLM

# Mapping of llm type and model_name prefixes to their respective classes
LLM_MODEL_MAP: Dict[str, Dict[str, Type[LLM]]] = {
    "openai": {
        "default": OpenAI,
    },
    "bedrock": {
        "anthropic": BedRockAnthropic,
        "cohere": BedRockCohere,
        "meta": BedRockLlama2,
    },
    "cohere": {
        "command": Cohere,
    },
}

def get_llm(llm: str, model_name: str, parser: AnnotationParser) -> "LLM":
    if llm == "openai":
        llm_inst = OpenAI(model_name, parser)
    elif llm == "bedrock" and model_name.startswith("anthropic"):
        logger.debug("Using bedrock provider with Anthropic model")
        llm_inst = BedRockAnthropic(model_name, parser)
    elif llm == "bedrock" and model_name.startswith("cohere"):
        logger.debug("Using bedrock provider with Cohere model")
        llm_inst = BedRockCohere(model_name, parser)
    elif llm == "bedrock" and model_name.startswith("meta"):
        logger.debug("Using bedrock provider with Cohere model")
        llm_inst = BedRockLlama2(model_name, parser)
    elif llm == "cohere" and model_name.startswith("command"):
        logger.debug("Using Cohere model")
        llm_inst = Cohere(model_name, parser)
    else:
        raise Exception(f"Model not implemented {llm}, {model_name}")
    logger.debug(
        f"Using {llm_inst.__class__.__name__} provider with model {model_name}"
    )
    return llm_inst
def get_llm(llm: str, model_name: str, parser: AnnotationParser, model_settings: Optional[Dict] = None) -> LLM:
    if llm not in LLM_MODEL_MAP:
        raise ValueError(f"LLM type '{llm}' is not implemented")

    # Extract the prefix from the model name. Adjust this logic as necessary.
    model_prefix = model_name.split(".")[0]  # e.g. 'anthropic' from 'anthropic.claude-3-sonnet-20240229-v1:0'

    # Fall back to the "default" entry (e.g. OpenAI) when the prefix itself is not mapped
    model_class = LLM_MODEL_MAP[llm].get(model_prefix) or LLM_MODEL_MAP[llm].get("default")

    if model_class is None:
        raise ValueError(f"Model prefix '{model_prefix}' for LLM type '{llm}' is not implemented")

    logger.debug(f"Using {model_class.__name__} provider with model {model_name}")

    return model_class(model_name, parser, model_settings)
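For illustration, a sketch of how the table-driven dispatch resolves a Bedrock model id. The parser argument is stubbed with None here; real callers pass an AnnotationParser. OpenAI ids such as "gpt-4" contain no dot, so the prefix is the whole name and the "default" fallback keeps that path working.

# Hypothetical walk-through of the lookup above.
model_name = "anthropic.claude-3-sonnet-20240229-v1:0"
prefix = model_name.split(".")[0]        # -> "anthropic"
cls = LLM_MODEL_MAP["bedrock"][prefix]   # -> BedRockAnthropic

llm = get_llm("bedrock", model_name, parser=None, model_settings={"temperature": 0.2})
print(type(llm).__name__)  # BedRockAnthropic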
70 changes: 44 additions & 26 deletions src/pydantic_prompter/llm_providers/bedrock_anthropic.py
@@ -1,12 +1,21 @@
import json
import random
from typing import List

from typing import List, Optional, Dict
from fix_busted_json import repair_json, largest_json
from pydantic_prompter.common import Message, logger
from pydantic_prompter.llm_providers.bedrock_base import BedRock

from pydantic_prompter.annotation_parser import AnnotationParser

class BedRockAnthropic(BedRock):
    def __init__(self, model_name: str, parser: AnnotationParser, model_settings: Optional[Dict] = None):
        super().__init__(model_name, parser)
        self.model_settings = model_settings or {
            "temperature": random.uniform(0, 1),
            "max_tokens": 8000,
            "stop_sequences": ["Human:"],
            "anthropic_version": "bedrock-2023-05-31",
        }

    def _build_prompt(self, messages: List[Message], params: dict | str):
        return "\n".join([m.content for m in messages])

@@ -45,37 +54,46 @@ def call(
    ) -> str:

        if scheme:
system_message = f"""Act like a REST API that answers the question contained in <question> tags.
Your response should be within <json></json> xml tags in JSON format with the schema
specified in the <json_schema> tags.
DO NOT add any other text other than the JSON response
system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided.
Your response should be a valid JSON format, strictly adhering to the Pydantic schema provided in the pydantic_schema section.
Stick to the facts and details in the provided data, and follow the guidelines closely.
Respond in a structured JSON format according to the provided schema.
DO NOT add any other text other than the requested JSON response.

## pydantic_schema:

<json_schema>
{json.dumps(scheme, indent=4)}
</json_schema>

"""
        else:  # return_type:
            system_message = f"""Act like an answer bot that answers the question contained in <question> tags.
Your response should be within <{return_type}></{return_type}> xml tags in {return_type} format.
DO NOT add any other text other than the STRING response
"""
system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided.
Your response should be according to the format requested in the return_type section.
Stick to the facts and details in the provided data, and follow the guidelines closely.
Respond in a structured JSON format according to the provided schema.
DO NOT add any other text other than the requested return_type response.

## return_type:

{return_type}

"""

        final_messages = [m.model_dump() for m in messages]
        final_messages = self.fix_messages(final_messages)
        body = json.dumps(
            {
                "system": system_message,
                "max_tokens": 8000,
                "messages": final_messages,
                "stop_sequences": [self._stop_sequence],
                "temperature": random.uniform(0, 1),
                "anthropic_version": "bedrock-2023-05-31",
            }
        )

        response = self._boto_invoke(body)
        res = response.get("body").read().decode()
        response_body = json.loads(res)
        # Defaults below apply only when a key is missing; user-supplied model_settings win via the trailing spread
        body = {
            "system": system_message,
            "messages": final_messages,
            "stop_sequences": self.model_settings.get("stop_sequences", [self._stop_sequence]),
            "anthropic_version": self.model_settings.get("anthropic_version", "bedrock-2023-05-31"),
            **self.model_settings,
        }

        response = self._boto_invoke(json.dumps(body))
        response_text = response.get("body").read().decode()
        response_json = repair_json(largest_json(response_text))
        response_body = json.loads(response_json)

        logger.info(response_body)
        return response_body.get("content")[0]["text"]
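The body merge above relies on dict-literal semantics: the .get(...) defaults fill in missing keys, while the trailing **self.model_settings spread wins whenever the caller supplied a key. A small sketch with made-up values:

# Later keys win in a dict literal, so user settings override the defaults.
user_settings = {"stop_sequences": ["###"], "temperature": 0.2}
body = {
    "stop_sequences": ["Human:"],  # default, kept only if the user omitted the key
    **user_settings,
}
print(body)  # {'stop_sequences': ['###'], 'temperature': 0.2}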
18 changes: 13 additions & 5 deletions src/pydantic_prompter/llm_providers/bedrock_base.py
@@ -2,9 +2,8 @@
import json
import random
from typing import List

from jinja2 import Template

from fix_busted_json import repair_json
from pydantic_prompter.common import Message, logger
from pydantic_prompter.exceptions import BedRockAuthenticationError
from pydantic_prompter.llm_providers.base import LLM
@@ -63,15 +62,22 @@ def _boto_invoke(self, body):
        try:
            logger.debug(f"Request body: \n{body}")
            import boto3

            from botocore.config import Config

            session = boto3.Session(
                aws_access_key_id=self.settings.aws_access_key_id,
                aws_secret_access_key=self.settings.aws_secret_access_key,
                aws_session_token=self.settings.aws_session_token,
                profile_name=self.settings.aws_profile,
                region_name=self.settings.aws_default_region,
            )
            client = session.client("bedrock-runtime")
            config = Config(
                read_timeout=120,  # 2 minutes before timing out
                retries={"max_attempts": 5, "mode": "adaptive"},  # retry up to 5 times, adaptively
                max_pool_connections=50,  # allow for significant concurrency
            )
            client = session.client("bedrock-runtime", config=config)
            # execute the model
            response = client.invoke_model(
                body=body,
                modelId=self.model_name,
@@ -101,6 +107,8 @@ def call(
        )

        response = self._boto_invoke(body)
        response_body = json.loads(response.get("body").read().decode())
        response_text = response.get("body").read().decode()
        response_json = repair_json(response_text)
        response_body = json.loads(response_json)
        logger.info(response_body)
        return response_body.get("completion")
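The adaptive retry mode adds client-side rate limiting on top of standard exponential backoff, which suits throttling-prone Bedrock endpoints. A self-contained sketch of the same client construction, assuming credentials are resolved from the environment rather than pydantic_prompter settings:

import boto3
from botocore.config import Config

config = Config(
    read_timeout=120,                                 # long generations need a generous read timeout
    retries={"max_attempts": 5, "mode": "adaptive"},  # backoff plus client-side rate limiting
    max_pool_connections=50,                          # room for concurrent invocations
)
client = boto3.Session().client("bedrock-runtime", config=config)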
12 changes: 7 additions & 5 deletions src/pydantic_prompter/prompter.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Dict

from jinja2 import Template
from retry import retry
@@ -15,11 +15,11 @@


class _Pr:
    def __init__(self, function, llm: str, model_name: str, jinja: bool):
    def __init__(self, function, llm: str, model_name: str, jinja: bool, model_settings: Optional[Dict] = None):
        self.jinja = jinja
        self.function = function
        self.parser = AnnotationParser.get_parser(function)
        self.llm = get_llm(llm=llm, parser=self.parser, model_name=model_name)
        self.llm = get_llm(llm=llm, parser=self.parser, model_name=model_name, model_settings=model_settings)

    @retry(tries=3, delay=1, logger=logger, exceptions=(Retryable,))
    def __call__(self, *args, **inputs):
@@ -92,15 +92,17 @@ def call_llm(self, llm_data: LLMDataAndResult) -> LLMDataAndResult:


class Prompter:
    def __init__(self, llm: str, model_name: str, jinja=False):
    def __init__(self, llm: str, model_name: str, jinja=False, model_settings: Optional[Dict] = None):
        self.model_name = model_name
        self.llm = llm
        self.jinja = jinja
        self.model_settings = model_settings

    def __call__(self, function):
        return _Pr(
            function=function,
            jinja=self.jinja,
            llm=self.llm,
            model_name=self.model_name,
        )
            model_settings=self.model_settings,
        )
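End to end, the new parameter lets callers pin sampling behaviour per prompter instead of inheriting the random-temperature default. A hedged sketch, assuming the library's decorator/docstring prompt style; the schema and question are made up:

from pydantic import BaseModel
from pydantic_prompter import Prompter

class Answer(BaseModel):
    text: str

@Prompter(
    llm="bedrock",
    model_name="anthropic.claude-3-sonnet-20240229-v1:0",
    model_settings={"temperature": 0.2, "max_tokens": 2000},  # overrides the defaults in BedRockAnthropic
)
def ask(question: str) -> Answer:
    """
    - user: {question}
    """

print(ask(question="What is Pydantic?").text)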