add refactored LLM benchmark code, initial commit
Showing 14 changed files with 838 additions and 0 deletions.
Empty file.
@@ -0,0 +1,39 @@
import argparse
from .clients import client_config_classes
from .benchmark_runner import BenchmarkConfig


def parse_args_to_configs():
    def add_model(parser, model):
        # Expose every field of a pydantic model as a command-line argument.
        fields = model.model_fields
        for name, field in fields.items():
            nargs = None
            field_type = field.annotation
            if getattr(field.annotation, "_name", "") == "List":
                # List[...] fields accept one or more values on the command line.
                nargs = "+"
                field_type = field.annotation.__args__[0]
            parser.add_argument(
                f"--{name}",
                dest=name,
                nargs=nargs,
                type=field_type,
                required=field.is_required(),
                default=None if field.is_required() else field.default,
                help=field.description or "",
            )

    # First pass: benchmark-level options.
    parser = argparse.ArgumentParser()
    add_model(parser, BenchmarkConfig)
    args, remaining_args = parser.parse_known_args()
    unused_args = set(remaining_args)
    benchmark_config = BenchmarkConfig(**vars(args))

    # Second pass: options of the client backend selected by --api.
    client_config_class = client_config_classes[benchmark_config.api]
    parser = argparse.ArgumentParser()
    add_model(parser, client_config_class)
    args, remaining_args = parser.parse_known_args()
    unused_args = unused_args.intersection(remaining_args)
    client_config = client_config_class(**vars(args))

    # Anything that neither parser recognized is a user error.
    if unused_args:
        raise ValueError(f"Unused arguments: {unused_args}")

    return benchmark_config, client_config
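For orientation, a minimal sketch of how this two-pass parsing might be driven end to end. The "src" package path and the flag values are illustrative assumptions, and "--api dummy" selects the self-contained dummy backend registered further down in this commit:

# Hedged usage sketch (not part of the commit): build both configs from argv
# and hand them to the runner. Package path and flag values are assumptions.
import sys

from src.arg_parsing import parse_args_to_configs
from src.benchmark_runner import BenchmarkRunner

sys.argv = ["llm-bench", "--model", "facebook/opt-125m", "--api", "dummy", "--num_clients", "1", "2"]
benchmark_config, client_config = parse_args_to_configs()
BenchmarkRunner(benchmark_config, client_config).run()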
@@ -0,0 +1,181 @@
from .config import BaseConfigModel
from .prompt import PromptGenerator, PromptConfig
from .clients import client_classes
from typing import List, Optional
from pydantic import Field
from pathlib import Path
import multiprocessing
import threading
import queue
import time
import yaml
import itertools
from tqdm import tqdm
from loguru import logger


class BenchmarkConfig(BaseConfigModel):
    model: str = Field(..., description="HuggingFace.co model name")
    api: str = "azure_ml"
    warmup_requests: int = 1
    result_dir: Path = Path("./results")
    use_threading: bool = False
    config_files: List[Path] = []
    num_clients: List[int] = [1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32]
    prompt_generator_seed: Optional[int] = None
    num_requests_per_client: int = 16
    max_prompt_length: int = 4000
    prompt_length: List[int] = [2600]
    prompt_length_var: float = 0.3
    max_new_tokens: List[int] = [60]
    max_new_tokens_var: float = 0.3
    streaming: bool = False


class BenchmarkRunner:
    def __init__(self, benchmark_config: BaseConfigModel, client_config: BaseConfigModel) -> None:
        logger.info("Initializing Benchmark Runner")
        self.config = benchmark_config
        self.client_config = client_config
        self.client_class = client_classes[self.config.api]

        # Clients run as separate processes by default; threads are an opt-in alternative.
        self.runnable_cls = multiprocessing.Process
        self.barrier_cls = multiprocessing.Barrier
        self.queue_cls = multiprocessing.Queue
        if self.config.use_threading:
            self.runnable_cls = threading.Thread
            self.barrier_cls = threading.Barrier
            self.queue_cls = queue.Queue

    def _generate_prompts(self, prompt_config: PromptConfig, num_clients: int) -> None:
        logger.info("Generating Prompts")
        prompt_generator = PromptGenerator(prompt_config)
        warmup_prompts = self.config.warmup_requests * num_clients
        workload_prompts = self.config.num_requests_per_client * num_clients
        for prompt in prompt_generator(warmup_prompts + workload_prompts):
            self.query_queue.put(prompt)
        logger.info(f"Generated {warmup_prompts} warmup and {workload_prompts} workload prompts.")

    def _launch_clients(self, num_clients):
        logger.info(f"Launching {num_clients} client(s)")
        self.barrier = self.barrier_cls(num_clients + 1)
        processes = [
            self.runnable_cls(
                target=self._run_client,
                args=(
                    self.barrier,
                    self.query_queue,
                    self.result_queue,
                    self.client_class,
                    self.client_config,
                    self.config.warmup_requests,
                ),
            )
            for _ in range(num_clients)
        ]
        for p in processes:
            p.start()

        total_prompts = num_clients * self.config.num_requests_per_client
        pbar = tqdm(total=total_prompts)

        self.barrier.wait()  # Barrier 1 for master process

        # Poll the result queue to drive the progress bar until all workload responses arrive.
        num_results = 0
        while num_results != total_prompts:
            num_results = self.result_queue.qsize()
            pbar.update(num_results - pbar.n)
            time.sleep(1)
        pbar.close()

        self.barrier.wait()  # Barrier 2 for master process

    @staticmethod
    def _run_client(barrier, query_queue, result_queue, client_class, client_config, warmup_requests):
        client = client_class(client_config)

        # Warmup requests are sent before the first barrier and excluded from the results.
        for _ in range(warmup_requests):
            prompt = query_queue.get(timeout=1.0)
            request_kwargs = client.prepare_request(prompt)
            raw_response = client.send_request(request_kwargs)
            response = client.process_response(raw_response)

        barrier.wait()  # Barrier 1 for client process
        try:
            while not query_queue.empty():
                prompt = query_queue.get(timeout=1.0)
                request_kwargs = client.prepare_request(prompt)
                start_time = time.time()
                raw_response = client.send_request(request_kwargs)
                end_time = time.time()
                response = client.process_response(raw_response)
                response.request_time = end_time - start_time
                result_queue.put_nowait(response)
        except queue.Empty:
            pass

        barrier.wait()  # Barrier 2 for client process

    def _benchmark_settings(self):
        prompt_config_keys = list(PromptConfig.model_fields.keys()) + ["num_clients"]

        configs_list = []
        for f in self.config.config_files:
            logger.info(f"Generating benchmark run settings from config file: {f}")
            with open(f, "r") as fh:
                file_config = yaml.safe_load(fh)
            # Fill in any keys the file omits from the command-line config.
            for key in prompt_config_keys:
                if key not in file_config:
                    file_config[key] = getattr(self.config, key)
            configs_list.append(file_config)

        if not configs_list:
            logger.info("Generating benchmark run settings from command line args")
            configs_list.append({key: getattr(self.config, key) for key in prompt_config_keys})

        # Expand every config into the cartesian product of its list-valued fields.
        all_config_product = []
        for config in configs_list:
            for k, v in config.items():
                if not isinstance(v, (list, tuple)):
                    config[k] = [v]
            for vals in itertools.product(*[config[k] for k in prompt_config_keys]):
                all_config_product.append({k: v for k, v in zip(prompt_config_keys, vals)})

        logger.info(f"Generated {len(all_config_product)} benchmark run setting(s)")

        for config in all_config_product:
            num_clients = config.pop("num_clients")
            prompt_config = PromptConfig(**config)
            yield num_clients, prompt_config

    def _clear_queues(self):
        self.query_queue = self.queue_cls()
        self.result_queue = self.queue_cls()

    def _save_results(self, num_clients, prompt_config):
        response_details = []
        while len(response_details) != num_clients * self.config.num_requests_per_client:
            res = self.result_queue.get()
            # vLLM returns concatenated tokens
            response_details.append(res)
        return response_details

    def run(self):
        self.client_class.start_service(self.client_config)
        for num_clients, prompt_config in self._benchmark_settings():
            logger.info(f"Running benchmark with {num_clients} client(s) and prompt config: {prompt_config}")
            self._clear_queues()
            self._generate_prompts(prompt_config=prompt_config, num_clients=num_clients)
            # self._prepare_requests()
            self._launch_clients(num_clients=num_clients)
            # self._process_responses()
            rd = self._save_results(prompt_config=prompt_config, num_clients=num_clients)
            print(len(rd))
        self.client_class.stop_service(self.client_config)


if __name__ == "__main__":
    from .arg_parsing import parse_args_to_configs

    benchmark_config, client_config = parse_args_to_configs()
    benchmark_runner = BenchmarkRunner(benchmark_config, client_config)
    benchmark_runner.run()
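The _benchmark_settings generator above expands every list-valued setting into a full cartesian product of runs. A self-contained sketch of that expansion, with made-up keys and values rather than the real PromptConfig fields, behaves like this:

import itertools

# Hypothetical sweep: list-valued entries are swept, scalars are fixed.
config = {"prompt_length": [1300, 2600], "max_new_tokens": [60, 128], "num_clients": 4}
keys = list(config.keys())
for k, v in config.items():
    if not isinstance(v, (list, tuple)):
        config[k] = [v]
settings = [dict(zip(keys, vals)) for vals in itertools.product(*[config[k] for k in keys])]
# Four settings: every (prompt_length, max_new_tokens) pair, each with num_clients=4.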
@@ -0,0 +1,7 @@
from .azure_ml_client import AzureMLClientConfig, AzureMLClient
from .fastgen_client import FastGenClientConfig, FastGenClient
from .vllm_client import vLLMClientConfig, vLLMClient
from .dummy_client import DummyClientConfig, DummyClient

client_config_classes = {"dummy": DummyClientConfig, "azure_ml": AzureMLClientConfig, "fastgen": FastGenClientConfig, "vllm": vLLMClientConfig}
client_classes = {"dummy": DummyClient, "azure_ml": AzureMLClient, "fastgen": FastGenClient, "vllm": vLLMClient}
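These two dicts are the lookup tables the harness uses to pick a backend from the --api flag. A brief sketch of that lookup; the src import path is an assumption about how the package is laid out:

# Hedged sketch: resolve a backend by its registry key.
from src.clients import client_classes, client_config_classes

api = "dummy"                                  # one of: dummy, azure_ml, fastgen, vllm
client_config = client_config_classes[api]()   # DummyClientConfig has no required fields
client = client_classes[api](client_config)    # instantiate the matching client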
benchmarks/inference/llm-bench/src/clients/azure_ml_client.py (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
from .base import BaseClient
from ..config import BaseConfigModel
from ..status import Status
from ..prompt import Prompt
from ..response import Response

import requests
import json
from typing import Any, Dict, Optional


class AzureMLClientConfig(BaseConfigModel):
    api_url: str = ""
    api_key: str = ""
    deployment_name: str = ""


class AzureMLClient(BaseClient):
    def __init__(self, config: AzureMLClientConfig) -> None:
        self.api_url = config.api_url
        self.api_key = config.api_key
        self.deployment_name = config.deployment_name

    @staticmethod
    def start_service(config: AzureMLClientConfig) -> Status:
        pass

    @staticmethod
    def stop_service(config: AzureMLClientConfig) -> Status:
        pass

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        if prompt.streaming:
            raise ValueError("AzureMLClient does not support streaming prompts.")

        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.api_key),
            "azureml-model-deployment": self.deployment_name,
        }
        pload = {
            "input_data": {
                "input_string": [
                    prompt.text,
                ],
                "parameters": {
                    "max_tokens": prompt.max_new_tokens,
                    "return_full_text": prompt.return_full_text,
                },
            }
        }
        return {"url": self.api_url, "headers": headers, "json": pload, "timeout": 180}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        while True:
            try:  # Sometimes the AML endpoint will return an error, so we send the request again
                response = requests.post(**request_kwargs)
                output = json.loads(response.content)
                break
            except Exception as e:
                print(f"Connection failed with {e}. Retrying AML request")

        return output

    def process_response(self, raw_response: Any) -> Response:
        response_text = raw_response[0]
        return Response(response_text)
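For reference, the request body that prepare_request posts to the scoring URL has the shape below. The concrete values are placeholders, and the exact response schema depends on the deployed model; process_response only assumes the first element of the decoded JSON holds the generated text:

# Shape of the JSON body built by prepare_request (placeholder values).
payload = {
    "input_data": {
        "input_string": ["Once upon a time"],
        "parameters": {"max_tokens": 60, "return_full_text": False},
    }
}
# Sent as: requests.post(url=<scoring URL>, headers={API key, deployment name}, json=payload, timeout=180)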
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod

from ..prompt import Prompt
from ..response import Response
from ..status import Status
from ..config import BaseConfigModel

from typing import Any, Dict


class BaseClient(ABC):
    @abstractmethod
    def __init__(self):
        pass

    @staticmethod
    @abstractmethod
    def start_service(config: BaseConfigModel) -> Status:
        pass

    @staticmethod
    @abstractmethod
    def stop_service(config: BaseConfigModel) -> Status:
        pass

    @abstractmethod
    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        pass

    @abstractmethod
    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        pass

    @abstractmethod
    def process_response(self, raw_response: Any) -> Response:
        pass
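Any new backend implements these six hooks plus a config model and then registers both in clients/__init__.py; the dummy client in the next file is the smallest in-tree example of the pattern. A hedged sketch of another trivial backend, purely illustrative and not part of the commit:

# Assumes the same imports as the abstract base module above.
class EchoClientConfig(BaseConfigModel):
    prefix: str = "echo: "  # hypothetical option

class EchoClient(BaseClient):
    def __init__(self, config: EchoClientConfig) -> None:
        self.prefix = config.prefix

    @staticmethod
    def start_service(config: EchoClientConfig) -> Status:
        return Status("OK")  # nothing to launch

    @staticmethod
    def stop_service(config: EchoClientConfig) -> Status:
        return Status("OK")

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        return {"text": prompt.text}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        return self.prefix + request_kwargs["text"]

    def process_response(self, raw_response: Any) -> Response:
        return Response(raw_response)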
benchmarks/inference/llm-bench/src/clients/dummy_client.py (37 additions, 0 deletions)
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod

from ..config import BaseConfigModel
from ..prompt import Prompt
from ..response import Response
from ..status import Status
from .base import BaseClient

from typing import Any, Dict
import time
import random


class DummyClientConfig(BaseConfigModel):
    pass


class DummyClient(BaseClient):
    def __init__(self, config: DummyClientConfig) -> None:
        pass

    @staticmethod
    def start_service(config: DummyClientConfig) -> Status:
        return Status("OK")

    @staticmethod
    def stop_service(config: DummyClientConfig) -> Status:
        return Status("OK")

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        return {"input_text": prompt.text, "max_new_tokens": prompt.max_new_tokens}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        time.sleep(random.uniform(1, 2))
        # time.sleep(1)
        return request_kwargs["input_text"] * 2

    def process_response(self, raw_response: Any) -> Response:
        return Response(raw_response)
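A hedged sketch of a single-request dry run against this backend, mirroring what BenchmarkRunner._run_client does per prompt. The Prompt constructor fields are assumptions inferred from the attributes the clients read, and the src package path is likewise assumed:

import time

from src.clients.dummy_client import DummyClient, DummyClientConfig
from src.prompt import Prompt  # assumed to accept the fields the clients read

client = DummyClient(DummyClientConfig())
prompt = Prompt(text="hello", max_new_tokens=16, streaming=False)  # hypothetical field values

kwargs = client.prepare_request(prompt)
start = time.time()
raw = client.send_request(kwargs)         # sleeps 1-2 s, returns the input text doubled
response = client.process_response(raw)
response.request_time = time.time() - start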