diff --git a/docs/pages/concepts/generation.md b/docs/pages/concepts/generation.md
index 41895aec3..8f633c6bc 100644
--- a/docs/pages/concepts/generation.md
+++ b/docs/pages/concepts/generation.md
@@ -12,6 +12,8 @@ async def agenerate(
     output_parser: BaseOutputParser[OutputType],
     temperature: float = 0.7,
     structured_output: bool = False,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True
 ) -> OutputType:
     input_variables = re.findall(r"(?

 Here are a few examples of how to use the `agenerate` function:

@@ -37,6 +45,8 @@ async def agenerate_env_profile(
     inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
     examples: str = "",
     temperature: float = 0.7,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True
 ) -> tuple[EnvironmentProfile, str]:
     """
     Using langchain to generate the background
@@ -56,6 +66,8 @@ async def agenerate_env_profile(
         ),
         output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
         temperature=temperature,
+        bad_output_process_model=bad_output_process_model,
+        use_fixed_model_version=use_fixed_model_version
     )
 ```
 ### Other generation functions
@@ -66,6 +78,8 @@ Similarly, there are other utility functions that builds upon the `agenerate` fu
 async def agenerate_relationship_profile(
     model_name: str,
     agents_profiles: list[str],
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True
 ) -> tuple[RelationshipProfile, str]
 ```

@@ -78,5 +92,7 @@ async def agenerate_script(
     agent_name: str = "",
     history: str = "",
     single_step: bool = False,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True
 ) -> tuple[ScriptInteractionReturnType, str]
 ```
diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py
index 8d517ec57..92c4c7457 100644
--- a/sotopia/generation_utils/generate.py
+++ b/sotopia/generation_utils/generate.py
@@ -55,6 +55,8 @@
     "redis",
     "groq/llama3-70b-8192",
 ]
+# subject to future OpenAI changes
+DEFAULT_BAD_OUTPUT_PROCESS_MODEL = "gpt-4o-mini"
 OutputType = TypeVar("OutputType", bound=object)

 client = OpenAI()
@@ -304,6 +306,7 @@ def obtain_chain(
     input_variables: list[str],
     temperature: float = 0.7,
     max_retries: int = 6,
+    use_fixed_model_version: bool = True,
 ) -> RunnableSerializable[dict[Any, Any], BaseMessage]:
     """
     Using langchain to sample profiles for participants
@@ -315,7 +318,8 @@
         )
     )
     chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
-    model_name = _return_fixed_model_version(model_name)
+    if use_fixed_model_version:
+        model_name = _return_fixed_model_version(model_name)
     if model_name.startswith("together_ai"):
         model_name = "/".join(model_name.split("/")[1:])
     chat_openai = ChatOpenAI(
@@ -391,7 +395,8 @@ def format_bad_output_for_script(
     ill_formed_output: str,
     format_instructions: str,
     agents: list[str],
-    model_name: str = "gpt-4o-mini",
+    model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> BaseMessage:
     template = """
     Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information.
@@ -410,6 +415,7 @@ def format_bad_output_for_script(
         model_name=model_name,
         template=template,
         input_variables=re.findall(r"{(.*?)}", template),
+        use_fixed_model_version=use_fixed_model_version,
     )
     input_values = {
         "ill_formed_output": ill_formed_output,
@@ -425,7 +431,8 @@ def format_bad_output(
     ill_formed_output: BaseMessage,
     format_instructions: str,
-    model_name: str = "gpt-4o-mini",
+    model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> BaseMessage:
     template = """
     Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
@@ -439,6 +446,7 @@ def format_bad_output(
         model_name=model_name,
         template=template,
         input_variables=re.findall(r"{(.*?)}", template),
+        use_fixed_model_version=use_fixed_model_version,
     )
     input_values = {
         "ill_formed_output": ill_formed_output.content,
@@ -458,6 +466,8 @@ async def agenerate(
     output_parser: BaseOutputParser[OutputType],
     temperature: float = 0.7,
     structured_output: bool = False,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> OutputType:
     input_variables = re.findall(
         r"(?

@@ ... @@ async def agenerate_env_profile(
     inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
     examples: str = "",
     temperature: float = 0.7,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> tuple[EnvironmentProfile, str]:
     """
     Using langchain to generate the background
@@ -549,6 +565,8 @@ async def agenerate_env_profile(
         ),
         output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
         temperature=temperature,
+        bad_output_process_model=bad_output_process_model,
+        use_fixed_model_version=use_fixed_model_version,
     )


@@ -556,6 +574,8 @@ async def agenerate_relationship_profile(
     model_name: str,
     agents_profiles: list[str],
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> tuple[RelationshipProfile, str]:
     """
     Using langchain to generate the background
@@ -572,6 +592,8 @@ async def agenerate_relationship_profile(
             agent_profile=agent_profile,
         ),
         output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
+        bad_output_process_model=bad_output_process_model,
+        use_fixed_model_version=use_fixed_model_version,
     )


@@ -586,6 +608,8 @@ async def agenerate_action(
     goal: str,
     temperature: float = 0.7,
     script_like: bool = False,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> AgentAction:
     """
     Using langchain to generate an example episode
@@ -635,6 +659,8 @@ async def agenerate_action(
             ),
             output_parser=PydanticOutputParser(pydantic_object=AgentAction),
             temperature=temperature,
+            bad_output_process_model=bad_output_process_model,
+            use_fixed_model_version=use_fixed_model_version,
         )
     except Exception:
         return AgentAction(action_type="none", argument="")
@@ -650,6 +676,8 @@ async def agenerate_script(
     agent_name: str = "",
     history: str = "",
     single_step: bool = False,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
 ) -> tuple[ScriptInteractionReturnType, str]:
     """
     Using langchain to generate an the script interactions between two agent
@@ -683,6 +711,8 @@ async def agenerate_script(
                 single_turn=True,
             ),
             temperature=temperature,
+            bad_output_process_model=bad_output_process_model,
+            use_fixed_model_version=use_fixed_model_version,
         )

     else:
@@ -705,6 +735,8 @@ async def agenerate_script(
                 single_turn=False,
             ),
             temperature=temperature,
+            bad_output_process_model=bad_output_process_model,
+            use_fixed_model_version=use_fixed_model_version,
         )
     except Exception as e:
         # TODO raise(e) # Maybe we do not want to return anything?
@@ -733,7 +765,12 @@ def process_history(


 @beartype
-async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> str:
+async def agenerate_init_profile(
+    model_name: str,
+    basic_info: dict[str, str],
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
+) -> str:
     """
     Using langchain to generate the background
     """
@@ -767,11 +804,19 @@ async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> str:
             secret=basic_info["secret"],
         ),
         output_parser=StrOutputParser(),
+        bad_output_process_model=bad_output_process_model,
+        use_fixed_model_version=use_fixed_model_version,
     )


 @beartype
-async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
+async def convert_narratives(
+    model_name: str,
+    narrative: str,
+    text: str,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
+) -> str:
     if narrative == "first":
         return await agenerate(
             model_name=model_name,
@@ -780,6 +825,8 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
     {text}""",
             input_values=dict(text=text),
             output_parser=StrOutputParser(),
+            bad_output_process_model=bad_output_process_model,
+            use_fixed_model_version=use_fixed_model_version,
         )
     elif narrative == "second":
         return await agenerate(
@@ -789,13 +836,20 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
     {text}""",
             input_values=dict(text=text),
             output_parser=StrOutputParser(),
+            bad_output_process_model=bad_output_process_model,
+            use_fixed_model_version=use_fixed_model_version,
         )
     else:
         raise ValueError(f"Narrative {narrative} is not supported.")


 @beartype
-async def agenerate_goal(model_name: str, background: str) -> str:
+async def agenerate_goal(
+    model_name: str,
+    background: str,
+    bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
+    use_fixed_model_version: bool = True,
+) -> str:
     """
     Using langchain to generate the background
     """
@@ -806,4 +860,6 @@ async def agenerate_goal(model_name: str, background: str) -> str:
     """,
         input_values=dict(background=background),
         output_parser=StrOutputParser(),
+        bad_output_process_model=bad_output_process_model,
+        use_fixed_model_version=use_fixed_model_version,
     )
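
A minimal usage sketch of the two new keyword arguments, for review purposes; it is not part of the patch. The generator model name and the inspiration prompt below are illustrative assumptions, while `agenerate_env_profile`'s signature, `DEFAULT_BAD_OUTPUT_PROCESS_MODEL`, and the parameter semantics come directly from the diff above.

```python
# Hypothetical caller exercising the new parameters (not part of this patch).
import asyncio

from sotopia.generation_utils.generate import (
    DEFAULT_BAD_OUTPUT_PROCESS_MODEL,  # "gpt-4o-mini", introduced by this patch
    agenerate_env_profile,
)


async def main() -> None:
    profile, prompt = await agenerate_env_profile(
        model_name="gpt-4o",  # assumed generator model; any supported name works
        inspiration_prompt="planning a surprise party on a tight budget",  # illustrative
        temperature=0.7,
        # Unparseable generations are now reformatted by a configurable model
        # rather than a hard-coded "gpt-4o-mini".
        bad_output_process_model=DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
        # False passes the model name through verbatim instead of pinning it
        # to a fixed snapshot via _return_fixed_model_version.
        use_fixed_model_version=False,
    )
    print(profile)


asyncio.run(main())
```

Since `use_fixed_model_version` defaults to `True`, the previous pinning behavior is preserved and existing callers are unaffected.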