diff --git a/pyproject.toml b/pyproject.toml index fbf1ced..a9796c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ bedrock = [ "botocore>=1.34.11", "boto3>=1.34.11", "awscli>=1.32.11", + "fix_busted_json>0.0.18", ] chere = [ "cohere>=4.46", diff --git a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py index 64f5fdf..356fdac 100644 --- a/src/pydantic_prompter/llm_providers/bedrock_anthropic.py +++ b/src/pydantic_prompter/llm_providers/bedrock_anthropic.py @@ -1,7 +1,7 @@ import json import random from typing import List, Optional, Dict - +from fix_busted_json import repair_json, largest_json from pydantic_prompter.common import Message, logger from pydantic_prompter.llm_providers.bedrock_base import BedRock from pydantic_prompter.annotation_parser import AnnotationParser @@ -54,20 +54,29 @@ def call( ) -> str: if scheme: - system_message = f"""Act like a REST API that answers the question contained in tags. - Your response should be within xml tags in JSON format with the schema - specified in the tags. - DO NOT add any other text other than the JSON response + system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided. + Your response should be a valid JSON format, strictly adhering to the Pydantic schema provided in the pydantic_schema section. + Stick to the facts and details in the provided data, and follow the guidelines closely. + Respond in a structured JSON format according to the provided schema. + DO NOT add any other text other than the requested JSON response. + + ## pydantic_schema: - {json.dumps(scheme, indent=4)} - + """ else: # return_type: - system_message = f"""Act like an answer bot that answers the question contained in tags. - Your response should be within <{return_type}> xml tags in {return_type} format . 
- DO NOT add any other text other than the STRING response - """ + system_message = f"""Act like a REST API that performs the requested operation the user asked according to guidelines provided. + Your response should be according to the format requested in the return_type section. + Stick to the facts and details in the provided data, and follow the guidelines closely. + Respond in a structured JSON format according to the provided schema. + DO NOT add any other text other than the requested return_type response. + + ## return_type: + + {return_type} + +""" final_messages = [m.model_dump() for m in messages] final_messages = self.fix_messages(final_messages) @@ -82,8 +91,9 @@ def call( } response = self._boto_invoke(json.dumps(body)) - res = response.get("body").read().decode() - response_body = json.loads(res) - + response_text = response.get("body").read().decode() + response_json = repair_json(largest_json(response_text)) + response_body = json.loads(response_json) + logger.info(response_body) return response_body.get("content")[0]["text"] diff --git a/src/pydantic_prompter/llm_providers/bedrock_base.py b/src/pydantic_prompter/llm_providers/bedrock_base.py index 15980d9..884149f 100644 --- a/src/pydantic_prompter/llm_providers/bedrock_base.py +++ b/src/pydantic_prompter/llm_providers/bedrock_base.py @@ -2,9 +2,8 @@ import json import random from typing import List - from jinja2 import Template - +from fix_busted_json import repair_json from pydantic_prompter.common import Message, logger from pydantic_prompter.exceptions import BedRockAuthenticationError from pydantic_prompter.llm_providers.base import LLM @@ -63,7 +62,8 @@ def _boto_invoke(self, body): try: logger.debug(f"Request body: \n{body}") import boto3 - + from botocore.config import Config + session = boto3.Session( aws_access_key_id=self.settings.aws_access_key_id, aws_secret_access_key=self.settings.aws_secret_access_key, @@ -71,7 +71,13 @@ def _boto_invoke(self, body): 
 profile_name=self.settings.aws_profile, region_name=self.settings.aws_default_region, ) - client = session.client("bedrock-runtime") + config = Config( + read_timeout=120, # 2 minutes before timeout + retries={'max_attempts': 5, 'mode': 'adaptive'}, # retry 5 times adaptively + max_pool_connections=50 # allow for significant concurrency + ) + client = session.client('bedrock-runtime', config=config) + #execute the model response = client.invoke_model( body=body, modelId=self.model_name, @@ -101,6 +107,8 @@ def call( ) response = self._boto_invoke(body) - response_body = json.loads(response.get("body").read().decode()) + response_text = response.get("body").read().decode() + response_json = repair_json(response_text) + response_body = json.loads(response_json) logger.info(response_body) return response_body.get("completion")