Add `bad_output_process_model` option and `use_fixed_model_version` option to all generation methods, to prevent future OpenAI API changes from breaking Sotopia. #196

Merged
16 changes: 16 additions & 0 deletions docs/pages/concepts/generation.md
@@ -12,6 +12,8 @@ async def agenerate(
output_parser: BaseOutputParser[OutputType],
temperature: float = 0.7,
structured_output: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> OutputType:
input_variables = re.findall(r"(?<!{){([^{}]+)}(?!})", template)
```
@@ -23,6 +25,12 @@ The `agenerate` function is versatile, taking the `output_parser` as an argument
* `gpt-4o-mini-2024-07-18` and later
* `gpt-4o-2024-08-06` and later

The `bad_output_process_model` is used to reformat ill-formed output. `DEFAULT_BAD_OUTPUT_PROCESS_MODEL` is set to `gpt-4o-mini` (at the time of Sotopia's publication we used `gpt-3.5-turbo-0613`, but OpenAI has since retired that model).

The `use_fixed_model_version` flag determines whether to pin the model to a fixed version. If set to `True`, the model version is pinned to the one used in the Sotopia paper. If set to `False`, the latest available version of the model is used.

Warning: some fixed model versions may no longer be available through the OpenAI API, so setting `use_fixed_model_version = True` can result in an error.

</Callout>
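
For a quick sense of how the two new options are passed, here is a minimal sketch of a direct `agenerate` call. The template, input values, and model names are invented for illustration, and the `StrOutputParser` import path may vary across langchain versions:

```python
import asyncio

from langchain_core.output_parsers import StrOutputParser  # import path assumed; varies by langchain version

from sotopia.generation_utils.generate import agenerate


async def main() -> None:
    # Invented template and inputs, just to show the two new keyword arguments.
    result = await agenerate(
        model_name="gpt-4o",
        template="Summarize the following scenario in one sentence: {scenario}",
        input_values=dict(scenario="Two friends negotiate who pays for dinner."),
        output_parser=StrOutputParser(),
        temperature=0.7,
        bad_output_process_model="gpt-4o-mini",  # model used to repair unparsable outputs
        use_fixed_model_version=False,           # use the latest available model versions
    )
    print(result)


asyncio.run(main())
```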

Here are a few examples of how to use the `agenerate` function:
@@ -37,6 +45,8 @@ async def agenerate_env_profile(
inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
examples: str = "",
temperature: float = 0.7,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[EnvironmentProfile, str]:
"""
Using langchain to generate the background
@@ -56,6 +66,8 @@ async def agenerate_env_profile(
),
output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version
)
```
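
A hedged usage sketch of `agenerate_env_profile` (the inspiration prompt and model names are invented; per the signature, the call returns the parsed `EnvironmentProfile` together with a string):

```python
# Sketch: sampling a new environment profile with the new options set explicitly.
env_profile, raw = await agenerate_env_profile(
    model_name="gpt-4o",                     # assumed model choice
    inspiration_prompt="negotiating rent with a landlord",
    temperature=0.7,
    bad_output_process_model="gpt-4o-mini",  # model that repairs unparsable outputs
    use_fixed_model_version=True,            # pin to the versions used in the Sotopia paper
)
```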
### Other generation functions
@@ -66,6 +78,8 @@ Similarly, there are other utility functions that build upon the `agenerate` function:
async def agenerate_relationship_profile(
model_name: str,
agents_profiles: list[str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[RelationshipProfile, str]
```
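
For instance (a sketch; the agent profile strings are invented):

```python
profile, raw = await agenerate_relationship_profile(
    model_name="gpt-4o",
    agents_profiles=[
        "Ava Thompson, 29, data scientist, outgoing and direct.",
        "Liam Chen, 34, high-school teacher, reserved and patient.",
    ],
    bad_output_process_model="gpt-4o-mini",
    use_fixed_model_version=True,
)
```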

@@ -78,5 +92,7 @@ async def agenerate_script(
agent_name: str = "",
history: str = "",
single_step: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[ScriptInteractionReturnType, str]
```
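
And a single-step script generation sketch; arguments not shown in the excerpt above (e.g. the scenario background) are elided and assumed to keep their defaults, and the history string is invented:

```python
script, raw = await agenerate_script(
    model_name="gpt-4o",
    agent_name="Ava Thompson",
    history="Ava: Hi! Do you have a minute to talk?",  # invented dialogue history
    single_step=True,
    bad_output_process_model="gpt-4o-mini",
    use_fixed_model_version=True,
)
```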
70 changes: 63 additions & 7 deletions sotopia/generation_utils/generate.py
@@ -55,6 +55,8 @@
"redis",
"groq/llama3-70b-8192",
]
# subject to future OpenAI changes
DEFAULT_BAD_OUTPUT_PROCESS_MODEL = "gpt-4o-mini"

OutputType = TypeVar("OutputType", bound=object)
client = OpenAI()
@@ -304,6 +306,7 @@ def obtain_chain(
input_variables: list[str],
temperature: float = 0.7,
max_retries: int = 6,
use_fixed_model_version: bool = True,
) -> RunnableSerializable[dict[Any, Any], BaseMessage]:
"""
Using langchain to sample profiles for participants
@@ -315,7 +318,8 @@
)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
model_name = _return_fixed_model_version(model_name)
if use_fixed_model_version:
model_name = _return_fixed_model_version(model_name)
if model_name.startswith("together_ai"):
model_name = "/".join(model_name.split("/")[1:])
chat_openai = ChatOpenAI(
@@ -391,7 +395,8 @@ def format_bad_output_for_script(
ill_formed_output: str,
format_instructions: str,
agents: list[str],
model_name: str = "gpt-4o-mini",
model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> BaseMessage:
template = """
Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information.
@@ -410,6 +415,7 @@ def format_bad_output_for_script(
model_name=model_name,
template=template,
input_variables=re.findall(r"{(.*?)}", template),
use_fixed_model_version=use_fixed_model_version,
)
input_values = {
"ill_formed_output": ill_formed_output,
@@ -425,7 +431,8 @@
def format_bad_output(
ill_formed_output: BaseMessage,
format_instructions: str,
model_name: str = "gpt-4o-mini",
model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> BaseMessage:
template = """
Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
@@ -439,6 +446,7 @@ def format_bad_output(
model_name=model_name,
template=template,
input_variables=re.findall(r"{(.*?)}", template),
use_fixed_model_version=use_fixed_model_version,
)
input_values = {
"ill_formed_output": ill_formed_output.content,
@@ -458,6 +466,8 @@ async def agenerate(
output_parser: BaseOutputParser[OutputType],
temperature: float = 0.7,
structured_output: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> OutputType:
input_variables = re.findall(
r"(?<!{){([^{}]+)}(?!})", template
@@ -473,6 +483,7 @@ async def agenerate(
template=template,
input_variables=input_variables,
temperature=temperature,
use_fixed_model_version=use_fixed_model_version,
)

if "format_instructions" not in input_values:
@@ -516,7 +527,10 @@ async def agenerate(
extra={"markup": True},
)
reformat_parsed_result = format_bad_output(
result, format_instructions=output_parser.get_format_instructions()
result,
format_instructions=output_parser.get_format_instructions(),
model_name=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
parsed_result = output_parser.invoke(reformat_parsed_result)
log.info(f"Generated result: {parsed_result}")
@@ -530,6 +544,8 @@ async def agenerate_env_profile(
inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
examples: str = "",
temperature: float = 0.7,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[EnvironmentProfile, str]:
"""
Using langchain to generate the background
@@ -549,13 +565,17 @@
),
output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


@beartype
async def agenerate_relationship_profile(
model_name: str,
agents_profiles: list[str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[RelationshipProfile, str]:
"""
Using langchain to generate the background
@@ -572,6 +592,8 @@ async def agenerate_relationship_profile(
agent_profile=agent_profile,
),
output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


@@ -586,6 +608,8 @@ async def agenerate_action(
goal: str,
temperature: float = 0.7,
script_like: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> AgentAction:
"""
Using langchain to generate an example episode
@@ -635,6 +659,8 @@ async def agenerate_action(
),
output_parser=PydanticOutputParser(pydantic_object=AgentAction),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
except Exception:
return AgentAction(action_type="none", argument="")
@@ -650,6 +676,8 @@ async def agenerate_script(
agent_name: str = "",
history: str = "",
single_step: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[ScriptInteractionReturnType, str]:
"""
Using langchain to generate the script interactions between two agents
@@ -683,6 +711,8 @@ async def agenerate_script(
single_turn=True,
),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)

else:
@@ -705,6 +735,8 @@
single_turn=False,
),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
except Exception as e:
# TODO raise(e) # Maybe we do not want to return anything?
@@ -733,7 +765,12 @@ def process_history(


@beartype
async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> str:
async def agenerate_init_profile(
model_name: str,
basic_info: dict[str, str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
"""
Using langchain to generate the background
"""
@@ -767,11 +804,19 @@ async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) ->
secret=basic_info["secret"],
),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


@beartype
async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
async def convert_narratives(
model_name: str,
narrative: str,
text: str,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
if narrative == "first":
return await agenerate(
model_name=model_name,
@@ -780,6 +825,8 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
{text}""",
input_values=dict(text=text),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
elif narrative == "second":
return await agenerate(
@@ -789,13 +836,20 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
{text}""",
input_values=dict(text=text),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
else:
raise ValueError(f"Narrative {narrative} is not supported.")


@beartype
async def agenerate_goal(model_name: str, background: str) -> str:
async def agenerate_goal(
model_name: str,
background: str,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
"""
Using langchain to generate the background
"""
@@ -806,4 +860,6 @@ async def agenerate_goal(model_name: str, background: str) -> str:
""",
input_values=dict(background=background),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)