ChatBedrock does not support Custom Model Import models #153

Open

lopezfelipe opened this issue Aug 12, 2024 · 5 comments

Comments

@lopezfelipe
Issue: Trying to invoke a custom Llama 3 8B model imported with Bedrock Custom Model Import. The model was fine-tuned and tested with the invoke_model function as described here.

from langchain_aws import ChatBedrock

llm = ChatBedrock(
    model_id="<<INSERT MODEL ARN HERE>>",
    model_kwargs={
        "max_tokens": 100,
        "top_p": 0.9,
        "temperature": 0.1,
    },
    provider="meta"
)

messages = [
    (
        "system",
        "You are a powerful text-to-SQL model. Your job is to answer questions about a database. You can use the following table schema for context: CREATE TABLE table_name_11 (tournament VARCHAR)",
    ),
    ("human", "Return the SQL query that answers the following question: Which Tournament has A in 1987?"),
]
ai_msg = llm.invoke(messages)
ai_msg

Observed issue: The following validation error is returned:

---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
Cell In[63], line 8
      1 messages = [
      2     (
      3         "system",
   (...)
      6     ("human", "Return the SQL query that answers the following question: Which Tournament has A in 1987?"),
      7 ]
----> 8 ai_msg = llm.invoke(messages)
      9 ai_msg

File /opt/conda/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:284, in BaseChatModel.invoke(self, input, config, stop, **kwargs)
    273 def invoke(
    274     self,
    275     input: LanguageModelInput,
   (...)
    279     **kwargs: Any,
    280 ) -> BaseMessage:
    281     config = ensure_config(config)
    282     return cast(
    283         ChatGeneration,
--> 284         self.generate_prompt(
    285             [self._convert_input(input)],
    286             stop=stop,
    287             callbacks=config.get("callbacks"),
    288             tags=config.get("tags"),
    289             metadata=config.get("metadata"),
    290             run_name=config.get("run_name"),
    291             run_id=config.pop("run_id", None),
    292             **kwargs,
    293         ).generations[0][0],
    294     ).message

File /opt/conda/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:756, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs)
    748 def generate_prompt(
    749     self,
    750     prompts: List[PromptValue],
   (...)
    753     **kwargs: Any,
    754 ) -> LLMResult:
    755     prompt_messages = [p.to_messages() for p in prompts]
--> 756     return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

File /opt/conda/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:613, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    611         if run_managers:
    612             run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
--> 613         raise e
    614 flattened_outputs = [
    615     LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
    616     for res in results
    617 ]
    618 llm_output = self._combine_llm_outputs([res.llm_output for res in results])

File /opt/conda/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:603, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    600 for i, m in enumerate(messages):
    601     try:
    602         results.append(
--> 603             self._generate_with_cache(
    604                 m,
    605                 stop=stop,
    606                 run_manager=run_managers[i] if run_managers else None,
    607                 **kwargs,
    608             )
    609         )
    610     except BaseException as e:
    611         if run_managers:

File /opt/conda/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:825, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    823 else:
    824     if inspect.signature(self._generate).parameters.get("run_manager"):
--> 825         result = self._generate(
    826             messages, stop=stop, run_manager=run_manager, **kwargs
    827         )
    828     else:
    829         result = self._generate(messages, stop=stop, **kwargs)

File /opt/conda/lib/python3.10/site-packages/langchain_aws/chat_models/bedrock.py:552, in ChatBedrock._generate(self, messages, stop, run_manager, **kwargs)
    548     usage_metadata = None
    550 llm_output["model_id"] = self.model_id
--> 552 msg = AIMessage(
    553     content=completion,
    554     additional_kwargs=llm_output,
    555     tool_calls=cast(List[ToolCall], tool_calls),
    556     usage_metadata=usage_metadata,
    557 )
    559 return ChatResult(
    560     generations=[
    561         ChatGeneration(
   (...)
    565     llm_output=llm_output,
    566 )

File /opt/conda/lib/python3.10/site-packages/langchain_core/messages/ai.py:94, in AIMessage.__init__(self, content, **kwargs)
     85 def __init__(
     86     self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
     87 ) -> None:
     88     """Pass in content as positional arg.
     89 
     90     Args:
     91         content: The content of the message.
     92         kwargs: Additional arguments to pass to the parent class.
     93     """
---> 94     super().__init__(content=content, **kwargs)

File /opt/conda/lib/python3.10/site-packages/langchain_core/messages/base.py:66, in BaseMessage.__init__(self, content, **kwargs)
     57 def __init__(
     58     self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
     59 ) -> None:
     60     """Pass in content as positional arg.
     61 
     62     Args:
     63         content: The string contents of the message.
     64         kwargs: Additional fields to pass to the
     65     """
---> 66     super().__init__(content=content, **kwargs)

File /opt/conda/lib/python3.10/site-packages/langchain_core/load/serializable.py:113, in Serializable.__init__(self, *args, **kwargs)
    111 def __init__(self, *args: Any, **kwargs: Any) -> None:
    112     """"""
--> 113     super().__init__(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pydantic/v1/main.py:341, in BaseModel.__init__(__pydantic_self__, **data)
    339 values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data)
    340 if validation_error:
--> 341     raise validation_error
    342 try:
    343     object_setattr(__pydantic_self__, '__dict__', values)

ValidationError: 1 validation error for AIMessage
content
  none is not an allowed value (type=type_error.none.not_allowed)
@rsgrewal-aws

@baskaryan Could you please let us know if there is a fix for this? The API example can be found here: https://github.com/aws-samples/amazon-bedrock-samples/blob/main/custom_models/import_models/llama-3/llama3-ngrammedqa-fine-tuning.ipynb (scroll to the def call_invoke_model_and_print(native_request): function).

@3coins
Collaborator

3coins commented Aug 15, 2024

@lopezfelipe @rsgrewal-aws
It would help to include a sample of the response from Bedrock in this case; it looks like the model is responding with a null value, which might be causing this issue.

@lopezfelipe
Author

lopezfelipe commented Aug 16, 2024

Here is an example of the same request being sent to the model using the invoke_model function.

import boto3
import json

sess = boto3.session.Session()  # `sess` was defined earlier in the original notebook; a plain boto3 session works here
region = sess.region_name
client = boto3.client("bedrock-runtime", region_name=region)
model_id = "<model-id>"

def get_sql_query(system_prompt, user_question):
    """
    Generate a SQL query using Llama 3 8B
    Remember to use the same template used in fine tuning
    """
    formatted_prompt = f"<s>[INST] <<SYS>>{system_prompt}<</SYS>>\n\n[INST]Human: {user_question}[/INST]\n\nAssistant:"
    print(formatted_prompt)
    native_request = {
        "prompt": formatted_prompt,
        "max_tokens": 100,
        "top_p": 0.9,
        "temperature": 0.1,
    }
    response = client.invoke_model(modelId=model_id,
                                   body=json.dumps(native_request))
    response_text = json.loads(response.get('body').read())["outputs"][0]["text"]

    return response_text

system_prompt = "You are a powerful text-to-SQL model. Your job is to answer questions about a database. You can use the following table schema for context: CREATE TABLE table_name_11 (tournament VARCHAR)"
user_question = "Return the SQL query that answers the following question: Which Tournament has A in 1987?"

query = get_sql_query(system_prompt, user_question).strip()
print(query)

The code above returns:

SELECT tournament FROM table_name_11 WHERE 1987 = "a"
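
For context: the imported model's raw response here exposes the text under outputs[0].text, while ChatBedrock's provider="meta" handling appears to expect a generation field, which would leave the parsed completion as None and trigger the validation error above. Below is a minimal workaround sketch (not a fix in langchain_aws) that calls the imported model directly with boto3 and wraps the result in an AIMessage, reusing the prompt template and response keys shown in this issue; the model ARN is a placeholder.

import json

import boto3
from langchain_core.messages import AIMessage

client = boto3.client("bedrock-runtime")
model_id = "<<INSERT MODEL ARN HERE>>"  # placeholder for the imported model ARN

def invoke_custom_llama(system_prompt: str, user_question: str) -> AIMessage:
    # Reuse the fine-tuning prompt template from this issue
    prompt = (
        f"<s>[INST] <<SYS>>{system_prompt}<</SYS>>\n\n"
        f"[INST]Human: {user_question}[/INST]\n\nAssistant:"
    )
    body = {"prompt": prompt, "max_tokens": 100, "top_p": 0.9, "temperature": 0.1}
    response = client.invoke_model(modelId=model_id, body=json.dumps(body))
    payload = json.loads(response["body"].read())
    # Imported Llama models respond with {"outputs": [{"text": ...}]} here,
    # rather than the {"generation": ...} shape the meta provider path in
    # langchain_aws appears to expect.
    return AIMessage(content=payload["outputs"][0]["text"])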

@rsgrewal-aws

I am wondering why it would respond with a null value; does that mean an error is being thrown somewhere? The second part to this is how we want to handle no response / null from the model. Ideally we should return the response as is, which means if it is null, then a string like "null"?
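
For reference, a minimal sketch of the null-handling being discussed, based on the AIMessage construction visible in the traceback above (langchain_aws/chat_models/bedrock.py, around line 552). This is only an illustration of the guard, not a proposed patch; completion, llm_output, tool_calls, and usage_metadata are the local names from the traceback.

# Sketch only: coerce a missing completion to an empty string so pydantic
# validation on AIMessage.content does not fail on None.
msg = AIMessage(
    content=completion if completion is not None else "",
    additional_kwargs=llm_output,
    tool_calls=cast(List[ToolCall], tool_calls),
    usage_metadata=usage_metadata,
)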

@rsgrewal-aws

@3coins Could you please let us know when we can expect some fixes? Secondly, as part of the restructure, can we move invoke to a separate method? That would let this class be open to extension more easily.
