[misc] update utils to support comparing multiple settings #9140

Merged (2 commits) on Oct 8, 2024
55 changes: 44 additions & 11 deletions tests/utils.py
@@ -310,14 +310,38 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API server.
     """
 
+    compare_all_settings(
+        model,
+        [arg1, arg2],
+        [env1, env2],
+        method=method,
+        max_wait_seconds=max_wait_seconds,
+    )
+
+
+def compare_all_settings(model: str,
+                         all_args: List[List[str]],
+                         all_envs: List[Optional[Dict[str, str]]],
+                         *,
+                         method: Literal["generate", "encode"] = "generate",
+                         max_wait_seconds: Optional[float] = None) -> None:
+    """
+    Launch API server with several different sets of arguments/environments
+    and compare the results of the API calls with the first set of arguments.
+    Args:
+        model: The model to test.
+        all_args: A list of argument lists to pass to the API server.
+        all_envs: A list of environment dictionaries to pass to the API server.
+    """
+
     trust_remote_code = False
-    for args in (arg1, arg2):
+    for args in all_args:
         if "--trust-remote-code" in args:
             trust_remote_code = True
             break
 
     tokenizer_mode = "auto"
-    for args in (arg1, arg2):
+    for args in all_args:
         if "--tokenizer-mode" in args:
             tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
             break
@@ -330,8 +354,10 @@ def compare_two_settings(model: str,
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt).input_ids
-    results = []
-    for args, env in ((arg1, env1), (arg2, env2)):
+    ref_results: List = []
+    for i, (args, env) in enumerate(zip(all_args, all_envs)):
+        compare_results: List = []
+        results = ref_results if i == 0 else compare_results
         with RemoteOpenAIServer(model,
                                 args,
                                 env_dict=env,
@@ -355,13 +381,20 @@ def compare_two_settings(model: str,
             else:
                 assert_never(method)
 
-    n = len(results) // 2
-    arg1_results = results[:n]
-    arg2_results = results[n:]
-    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
-        assert arg1_result == arg2_result, (
-            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
-            f"{arg1_result=} != {arg2_result=}")
+        if i > 0:
+            # if any setting fails, raise an error early
+            ref_args = all_args[0]
+            ref_envs = all_envs[0]
+            compare_args = all_args[i]
+            compare_envs = all_envs[i]
+            for ref_result, compare_result in zip(ref_results,
+                                                  compare_results):
+                assert ref_result == compare_result, (
+                    f"Results for {model=} are not the same.\n"
+                    f"{ref_args=} {ref_envs=}\n"
+                    f"{compare_args=} {compare_envs=}\n"
+                    f"{ref_result=}\n"
+                    f"{compare_result=}\n")
 
 
 def init_test_distributed_environment(
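For context, a minimal usage sketch of the new helper follows. The model name, CLI flags, and import path below are illustrative assumptions, not values taken from this PR. The only behavior the sketch relies on is what the diff establishes: the first args/envs pair is the reference setting, and each later pair is compared against it, with an assertion raised early on the first mismatch.

# Hypothetical usage of compare_all_settings; the import path depends on
# where the calling test lives, and the model/flags are placeholders.
from tests.utils import compare_all_settings

compare_all_settings(
    "facebook/opt-125m",               # assumed test model, for illustration
    [
        ["--enforce-eager"],           # reference arguments (index 0)
        [],                            # compared against the reference
        ["--tokenizer-mode", "auto"],  # also compared against index 0
    ],
    [None, None, None],                # one env dict (or None) per argument list
    method="generate",
)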