diff --git a/benchmarks/inference/mii/run_aml.sh b/benchmarks/inference/mii/run_aml.sh new file mode 100644 index 000000000..90ad50e2c --- /dev/null +++ b/benchmarks/inference/mii/run_aml.sh @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Run benchmark against AML endpoint +python ./run_benchmark.py \ +    --model \ +    --deployment_name \ +    --aml_api_url \ +    --aml_api_key \ +    --mean_prompt_length 2600 \ +    --mean_max_new_tokens 60 \ +    --num_requests 256 \ +    --backend aml + +### Generate the plots +python ./src/plot_th_lat.py + +echo "Find figures in ./plots/ and log outputs in ./results/" \ No newline at end of file diff --git a/benchmarks/inference/mii/src/client.py b/benchmarks/inference/mii/src/client.py index 916fe4f23..c0fd6a767 100644 --- a/benchmarks/inference/mii/src/client.py +++ b/benchmarks/inference/mii/src/client.py @@ -163,7 +163,11 @@ def get_response(response: requests.Response) -> List[str]: token_gen_time = [] start_time = time.time() response = requests.post(args.aml_api_url, headers=headers, json=pload) - output = get_response(response) + # Sometimes the AML endpoint will return an error, so we send the request again + try: + output = get_response(response) + except Exception as e: + return call_aml(input_tokens, max_new_tokens, args) return ResponseDetails( generated_tokens=output,