diff --git a/tests/utils.py b/tests/utils.py
index 55c813728b1e0..020c33b81129a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -310,14 +310,38 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API
             server.
     """
+    compare_all_settings(
+        model,
+        [arg1, arg2],
+        [env1, env2],
+        method=method,
+        max_wait_seconds=max_wait_seconds,
+    )
+
+
+def compare_all_settings(model: str,
+                         all_args: List[List[str]],
+                         all_envs: List[Optional[Dict[str, str]]],
+                         *,
+                         method: Literal["generate", "encode"] = "generate",
+                         max_wait_seconds: Optional[float] = None) -> None:
+    """
+    Launch API server with several different sets of arguments/environments
+    and compare the results of the API calls with the first set of arguments.
+    Args:
+        model: The model to test.
+        all_args: A list of argument lists to pass to the API server.
+        all_envs: A list of environment dictionaries to pass to the API server.
+    """
+
     trust_remote_code = False
-    for args in (arg1, arg2):
+    for args in all_args:
         if "--trust-remote-code" in args:
             trust_remote_code = True
             break
 
     tokenizer_mode = "auto"
-    for args in (arg1, arg2):
+    for args in all_args:
         if "--tokenizer-mode" in args:
             tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
             break
@@ -330,8 +354,10 @@ def compare_two_settings(model: str,
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt).input_ids
-    results = []
-    for args, env in ((arg1, env1), (arg2, env2)):
+    ref_results: List = []
+    for i, (args, env) in enumerate(zip(all_args, all_envs)):
+        compare_results: List = []
+        results = ref_results if i == 0 else compare_results
         with RemoteOpenAIServer(model,
                                 args,
                                 env_dict=env,
@@ -355,13 +381,20 @@ def compare_two_settings(model: str,
             else:
                 assert_never(method)
 
-    n = len(results) // 2
-    arg1_results = results[:n]
-    arg2_results = results[n:]
-    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
-        assert arg1_result == arg2_result, (
-            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
-            f"{arg1_result=} != {arg2_result=}")
+        if i > 0:
+            # if any setting fails, raise an error early
+            ref_args = all_args[0]
+            ref_envs = all_envs[0]
+            compare_args = all_args[i]
+            compare_envs = all_envs[i]
+            for ref_result, compare_result in zip(ref_results,
+                                                  compare_results):
+                assert ref_result == compare_result, (
+                    f"Results for {model=} are not the same.\n"
+                    f"{ref_args=} {ref_envs=}\n"
+                    f"{compare_args=} {compare_envs=}\n"
+                    f"{ref_result=}\n"
+                    f"{compare_result=}\n")
 
 
 def init_test_distributed_environment(
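
For reference, a minimal usage sketch of the new helper (not part of the diff; the model name and server flags below are illustrative placeholders, not values taken from this patch):

    # Hypothetical usage sketch of compare_all_settings / compare_two_settings.
    # The model name and CLI flags are placeholders chosen for illustration.
    from tests.utils import compare_all_settings, compare_two_settings

    model = "facebook/opt-125m"  # placeholder model

    # The first args/env pair is the reference setting; every later pair is
    # launched, queried with the same prompt, and its results asserted equal
    # to the reference results right after that server finishes.
    compare_all_settings(
        model,
        all_args=[
            ["--enforce-eager"],                            # reference setting
            ["--enforce-eager", "--load-format", "dummy"],  # compared setting
        ],
        all_envs=[None, None],
    )

    # compare_two_settings is now a thin wrapper that forwards exactly two
    # settings to compare_all_settings.
    compare_two_settings(model, ["--enforce-eager"], [], None, None)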