From 53b349ceebab4ba0ca03f9989c43044e90998759 Mon Sep 17 00:00:00 2001 From: Zohar Babin Date: Fri, 24 May 2024 10:56:49 +0300 Subject: [PATCH 1/3] Refactored get_llm function for improved maintainability and flexibility - Introduced a dictionary (LLM_MODEL_MAP) to map LLM types and model name prefixes to their respective classes. - Updated the model prefix extraction logic to correctly handle the structure of model names (e.g., extracting 'anthropic' from 'anthropic.claude-3-sonnet-20240229-v1:0'). - Improved error handling to provide more specific error messages when LLM types or model prefixes are not implemented. - Added type annotations and imports for better clarity and maintainability. - Enhanced logging to provide clear information on the LLM provider and model being used. These changes make the function more resilient, modular, and easier to extend with new LLM models. --- .../llm_providers/__init__.py | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/src/pydantic_prompter/llm_providers/__init__.py b/src/pydantic_prompter/llm_providers/__init__.py index 2c0532c..925fe0e 100644 --- a/src/pydantic_prompter/llm_providers/__init__.py +++ b/src/pydantic_prompter/llm_providers/__init__.py @@ -1,3 +1,4 @@ +from typing import Type, Dict, Union from pydantic_prompter.annotation_parser import AnnotationParser from pydantic_prompter.common import logger from pydantic_prompter.llm_providers.bedrock_anthropic import BedRockAnthropic @@ -5,26 +6,35 @@ from pydantic_prompter.llm_providers.bedrock_llama2 import BedRockLlama2 from pydantic_prompter.llm_providers.cohere import Cohere from pydantic_prompter.llm_providers.openai import OpenAI +from pydantic_prompter.llm_providers.base import LLM +# Mapping of llm type and model_name prefixes to their respective classes +LLM_MODEL_MAP: Dict[str, Dict[str, Type[LLM]]] = { + "openai": { + "default": OpenAI, + }, + "bedrock": { + "anthropic": BedRockAnthropic, + "cohere": BedRockCohere, + "meta": BedRockLlama2, + }, + "cohere": { + "command": Cohere, + } +} -def get_llm(llm: str, model_name: str, parser: AnnotationParser) -> "LLM": - if llm == "openai": - llm_inst = OpenAI(model_name, parser) - elif llm == "bedrock" and model_name.startswith("anthropic"): - logger.debug("Using bedrock provider with Anthropic model") - llm_inst = BedRockAnthropic(model_name, parser) - elif llm == "bedrock" and model_name.startswith("cohere"): - logger.debug("Using bedrock provider with Cohere model") - llm_inst = BedRockCohere(model_name, parser) - elif llm == "bedrock" and model_name.startswith("meta"): - logger.debug("Using bedrock provider with Cohere model") - llm_inst = BedRockLlama2(model_name, parser) - elif llm == "cohere" and model_name.startswith("command"): - logger.debug("Using Cohere model") - llm_inst = Cohere(model_name, parser) - else: - raise Exception(f"Model not implemented {llm}, {model_name}") - logger.debug( - f"Using {llm_inst.__class__.__name__} provider with model {model_name}" - ) - return llm_inst +def get_llm(llm: str, model_name: str, parser: AnnotationParser, model_settings: dict | None = None) -> LLM: + if llm not in LLM_MODEL_MAP: + raise ValueError(f"LLM type '{llm}' is not implemented") + + # Extract the prefix from the model name. Adjust this logic as necessary. + model_prefix = model_name.split('.')[0] # Extract 'anthropic' from 'anthropic.claude-3-sonnet-20240229-v1:0' + + model_class = LLM_MODEL_MAP.get(llm, {}).get(model_prefix, None) + + if model_class is None: + raise ValueError(f"Model prefix '{model_prefix}' for LLM type '{llm}' is not implemented") + + logger.debug(f"Using {model_class.__name__} provider with model {model_name}") + + return model_class(model_name, parser, model_settings) From 69c899f39e2943cb6335cde7c846b8281ed1ff4f Mon Sep 17 00:00:00 2001 From: Zohar Babin Date: Fri, 24 May 2024 10:58:23 +0300 Subject: [PATCH 2/3] Refactored Prompter and BedRockAnthropic for configurability and flexibility - Updated Prompter to allow dynamic configuration of model settings through the decorator. - Added `model_settings` parameter to the `@Prompter` decorator. - Modified `_Pr` class to pass `model_settings` to the LLM provider. - Refactored BedRockAnthropic to support configurable model settings. - Introduced `model_settings` parameter in the constructor to allow dynamic configuration of settings like `max_tokens`, `temperature`, `top_p`, `top_k`, `stop_sequences`, and `anthropic_version`. - Ensured `stop_sequences` and `anthropic_version` are always included in the request body. - Improved error handling and logging for better traceability. These changes enhance the flexibility and configurability of the prompter system, allowing for dynamic adjustment of LLM settings to suit different scenarios and requirements. --- .../llm_providers/bedrock_anthropic.py | 34 ++++++++++++------- src/pydantic_prompter/prompter.py | 12 ++++--- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py index 86aa113..64f5fdf 100644 --- a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py +++ b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py @@ -1,12 +1,21 @@ import json import random -from typing import List +from typing import List, Optional, Dict from pydantic_prompter.common import Message, logger from pydantic_prompter.llm_providers.bedrock_base import BedRock - +from pydantic_prompter.annotation_parser import AnnotationParser class BedRockAnthropic(BedRock): + def __init__(self, model_name: str, parser: AnnotationParser, model_settings: Optional[Dict] = None): + super().__init__(model_name, parser) + self.model_settings = model_settings or { + "temperature": random.uniform(0, 1), + "max_tokens": 8000, + "stop_sequences": ["Human:"], + "anthropic_version": "bedrock-2023-05-31", + } + def _build_prompt(self, messages: List[Message], params: dict | str): return "\n".join([m.content for m in messages]) @@ -62,18 +71,17 @@ def call( final_messages = [m.model_dump() for m in messages] final_messages = self.fix_messages(final_messages) - body = json.dumps( - { - "system": system_message, - "max_tokens": 8000, - "messages": final_messages, - "stop_sequences": [self._stop_sequence], - "temperature": random.uniform(0, 1), - "anthropic_version": "bedrock-2023-05-31", - } - ) - response = self._boto_invoke(body) + # Ensure stop_sequences and anthropic_version are always included + body = { + "system": system_message, + "messages": final_messages, + "stop_sequences": self.model_settings.get("stop_sequences", [self._stop_sequence]), + "anthropic_version": self.model_settings.get("anthropic_version", "bedrock-2023-05-31"), + **self.model_settings + } + + response = self._boto_invoke(json.dumps(body)) res = response.get("body").read().decode() response_body = json.loads(res) diff --git a/src/pydantic_prompter/prompter.py b/src/pydantic_prompter/prompter.py index 3a1c460..29006c5 100644 --- a/src/pydantic_prompter/prompter.py +++ b/src/pydantic_prompter/prompter.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Dict from jinja2 import Template from retry import retry @@ -15,11 +15,11 @@ class _Pr: - def __init__(self, function, llm: str, model_name: str, jinja: bool): + def __init__(self, function, llm: str, model_name: str, jinja: bool, model_settings: Optional[Dict] = None): self.jinja = jinja self.function = function self.parser = AnnotationParser.get_parser(function) - self.llm = get_llm(llm=llm, parser=self.parser, model_name=model_name) + self.llm = get_llm(llm=llm, parser=self.parser, model_name=model_name, model_settings=model_settings) @retry(tries=3, delay=1, logger=logger, exceptions=(Retryable,)) def __call__(self, *args, **inputs): @@ -92,10 +92,11 @@ def call_llm(self, llm_data: LLMDataAndResult) -> LLMDataAndResult: class Prompter: - def __init__(self, llm: str, model_name: str, jinja=False): + def __init__(self, llm: str, model_name: str, jinja=False, model_settings: Optional[Dict] = None): self.model_name = model_name self.llm = llm self.jinja = jinja + self.model_settings = model_settings def __call__(self, function): return _Pr( @@ -103,4 +104,5 @@ def __call__(self, function): jinja=self.jinja, llm=self.llm, model_name=self.model_name, - ) + model_settings=self.model_settings, + ) \ No newline at end of file From 5c0f4b68a0301e62c079ff51821bf4623b686dcd Mon Sep 17 00:00:00 2001 From: Zohar Babin Date: Sat, 25 May 2024 23:17:30 +0300 Subject: [PATCH 3/3] Upgrading anthropic bedrock model more resillient: boto client config updates, model prompt structure using mdown and better instructions, and using fix_busted_json to resolve cases with llm responding with extra text rather than just json object (requires a dependancy lib: fix_busted_json) --- pyproject.toml | 1 + .../llm_providers/bedrock_anthropic.py | 38 ++++++++++++------- .../llm_providers/bedrock_base.py | 18 ++++++--- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 00e45f1..9c6dda3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ bedrock = [ "botocore>=1.34.11", "boto3>=1.34.11", "awscli>=1.32.11", + "fix_busted_json>0.0.18", ] chere = [ "cohere>=4.46", diff --git a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py index 64f5fdf..356fdac 100644 --- a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py +++ b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py @@ -1,7 +1,7 @@ import json import random from typing import List, Optional, Dict - +from fix_busted_json import repair_json, largest_json from pydantic_prompter.common import Message, logger from pydantic_prompter.llm_providers.bedrock_base import BedRock from pydantic_prompter.annotation_parser import AnnotationParser @@ -54,20 +54,29 @@ def call( ) -> str: if scheme: - system_message = f"""Act like a REST API that answers the question contained in tags. - Your response should be within xml tags in JSON format with the schema - specified in the tags. - DO NOT add any other text other than the JSON response + system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided. + Your response should be a valid JSON format, strictly adhering to the Pydantic schema provided in the pydantic_schema section. + Stick to the facts and details in the provided data, and follow the guidelines closely. + Respond in a structured JSON format according to the provided schema. + DO NOT add any other text other than the requested JSON response. + + ## pydantic_schema: - {json.dumps(scheme, indent=4)} - + """ else: # return_type: - system_message = f"""Act like an answer bot that answers the question contained in tags. - Your response should be within <{return_type}> xml tags in {return_type} format . - DO NOT add any other text other than the STRING response - """ + system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided. + Your response should be according to the format requested in the return_type section. + Stick to the facts and details in the provided data, and follow the guidelines closely. + Respond in a structured JSON format according to the provided schema. + DO NOT add any other text other than the requested return_type response. + + ## return_type: + + {return_type} + +""" final_messages = [m.model_dump() for m in messages] final_messages = self.fix_messages(final_messages) @@ -82,8 +91,9 @@ def call( } response = self._boto_invoke(json.dumps(body)) - res = response.get("body").read().decode() - response_body = json.loads(res) - + response_text = response.get("body").read().decode() + response_json = repair_json(largest_json(response_text)) + response_body = json.loads(response_json) + logger.info(response_body) return response_body.get("content")[0]["text"] diff --git a/src/pydantic_prompter/llm_providers/bedrock_base.py b/src/pydantic_prompter/llm_providers/bedrock_base.py index 15980d9..884149f 100644 --- a/src/pydantic_prompter/llm_providers/bedrock_base.py +++ b/src/pydantic_prompter/llm_providers/bedrock_base.py @@ -2,9 +2,8 @@ import json import random from typing import List - from jinja2 import Template - +from fix_busted_json import repair_json from pydantic_prompter.common import Message, logger from pydantic_prompter.exceptions import BedRockAuthenticationError from pydantic_prompter.llm_providers.base import LLM @@ -63,7 +62,8 @@ def _boto_invoke(self, body): try: logger.debug(f"Request body: \n{body}") import boto3 - + from botocore.config import Config + session = boto3.Session( aws_access_key_id=self.settings.aws_access_key_id, aws_secret_access_key=self.settings.aws_secret_access_key, @@ -71,7 +71,13 @@ def _boto_invoke(self, body): profile_name=self.settings.aws_profile, region_name=self.settings.aws_default_region, ) - client = session.client("bedrock-runtime") + config = Config( + read_timeout=120, # 2 minutes before timeout + retries={'max_attempts': 5, 'mode': 'adaptive'}, # retry 5 times adaptivly + max_pool_connections=50 # allow for significant concurrency + ) + client = session.client('bedrock-runtime', config=config) + #execute the model response = client.invoke_model( body=body, modelId=self.model_name, @@ -101,6 +107,8 @@ def call( ) response = self._boto_invoke(body) - response_body = json.loads(response.get("body").read().decode()) + response_text = response.get("body").read().decode() + response_json = repair_json(response_text) + response_body = json.loads(response_json) logger.info(response_body) return response_body.get("completion")