add refactored LLM benchmark code, initial commit
mrwyattii committed Apr 30, 2024
1 parent fab5d06 commit be41d26
Showing 14 changed files with 838 additions and 0 deletions.
Empty file.
39 changes: 39 additions & 0 deletions benchmarks/inference/llm-bench/src/arg_parsing.py
@@ -0,0 +1,39 @@
import argparse
from .clients import client_config_classes
from .benchmark_runner import BenchmarkConfig

def parse_args_to_configs():
    def add_model(parser, model):
        fields = model.model_fields
        for name, field in fields.items():
            nargs = None
            field_type = field.annotation
            if getattr(field.annotation, "_name", "") == "List":
                nargs = "+"
                field_type = field.annotation.__args__[0]
            parser.add_argument(
                f"--{name}",
                dest=name,
                nargs=nargs,
                type=field_type,
                required=getattr(field, "required", False),
                default=getattr(field, "default", None),
                help=getattr(field, "description", ""),
            )

    parser = argparse.ArgumentParser()
    add_model(parser, BenchmarkConfig)
    args, remaining_args = parser.parse_known_args()
    unused_args = set(remaining_args)
    benchmark_config = BenchmarkConfig(**vars(args))

    client_config_class = client_config_classes[benchmark_config.api]
    parser = argparse.ArgumentParser()
    add_model(parser, client_config_class)
    args, remaining_args = parser.parse_known_args()
    unused_args = unused_args.intersection(remaining_args)
    client_config = client_config_class(**vars(args))

    if unused_args:
        raise ValueError(f"Unused arguments: {unused_args}")

    return benchmark_config, client_config
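A minimal standalone sketch (not part of this commit) of the pattern above: each pydantic model field becomes an argparse flag, so the CLI surface tracks BenchmarkConfig and the selected client config automatically. DemoConfig and the sample argv are hypothetical and only illustrate the mechanism; the commit's real entry point is parse_args_to_configs.

    # Hypothetical, self-contained sketch of the field-to-flag mapping (pydantic v2 assumed).
    import argparse
    from typing import List
    from pydantic import BaseModel, Field

    class DemoConfig(BaseModel):  # stand-in for BenchmarkConfig or a client config
        model: str = Field("facebook/opt-125m", description="model name")
        num_clients: List[int] = [1, 2]

    def add_model(parser, model):
        for name, field in model.model_fields.items():
            nargs, field_type = None, field.annotation
            if getattr(field.annotation, "_name", "") == "List":  # List[...] becomes a repeatable flag
                nargs = "+"
                field_type = field.annotation.__args__[0]
            parser.add_argument(f"--{name}", dest=name, nargs=nargs, type=field_type,
                                default=field.default, help=field.description or "")

    parser = argparse.ArgumentParser()
    add_model(parser, DemoConfig)
    print(DemoConfig(**vars(parser.parse_args(["--num_clients", "4", "8"]))))
    # -> model='facebook/opt-125m' num_clients=[4, 8]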
181 changes: 181 additions & 0 deletions benchmarks/inference/llm-bench/src/benchmark_runner.py
@@ -0,0 +1,181 @@
from .config import BaseConfigModel
from .prompt import PromptGenerator, PromptConfig
from .clients import client_classes
from typing import List, Optional
from pydantic import Field
from pathlib import Path
import multiprocessing
import threading
import queue
import time
import yaml
import itertools
from tqdm import tqdm
from loguru import logger


class BenchmarkConfig(BaseConfigModel):
    model: str = Field(..., description="HuggingFace.co model name")
    api: str = "azure_ml"
    warmup_requests: int = 1
    result_dir: Path = Path("./results")
    use_threading: bool = False
    config_files: List[Path] = []
    num_clients: List[int] = [1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32]
    prompt_generator_seed: Optional[int] = None
    num_requests_per_client: int = 16
    max_prompt_length: int = 4000
    prompt_length: List[int] = [2600]
    prompt_length_var: float = 0.3
    max_new_tokens: List[int] = [60]
    max_new_tokens_var: float = 0.3
    streaming: bool = False

class BenchmarkRunner:
    def __init__(self, benchmark_config: BaseConfigModel, client_config: BaseConfigModel) -> None:
        logger.info("Initializing Benchmark Runner")
        self.config = benchmark_config
        self.client_config = client_config
        self.client_class = client_classes[self.config.api]

        self.runnable_cls = multiprocessing.Process
        self.barrier_cls = multiprocessing.Barrier
        self.queue_cls = multiprocessing.Queue
        if self.config.use_threading:
            self.runnable_cls = threading.Thread
            self.barrier_cls = threading.Barrier
            self.queue_cls = queue.Queue

    def _generate_prompts(self, prompt_config: PromptConfig, num_clients: int) -> None:
        logger.info("Generating Prompts")
        prompt_generator = PromptGenerator(prompt_config)
        warmup_prompts = self.config.warmup_requests * num_clients
        workload_prompts = self.config.num_requests_per_client * num_clients
        for prompt in prompt_generator(warmup_prompts + workload_prompts):
            self.query_queue.put(prompt)
        logger.info(f"Generated {warmup_prompts} warmup and {workload_prompts} workload prompts.")

    def _launch_clients(self, num_clients):
        logger.info(f"Launching {num_clients} client(s)")
        self.barrier = self.barrier_cls(num_clients + 1)
        processes = [
            self.runnable_cls(
                target=self._run_client,
                args=(
                    self.barrier,
                    self.query_queue,
                    self.result_queue,
                    self.client_class,
                    self.client_config,
                    self.config.warmup_requests,
                ),
            )
            for _ in range(num_clients)
        ]
        for p in processes:
            p.start()

        total_prompts = num_clients * self.config.num_requests_per_client
        pbar = tqdm(total=total_prompts)

        self.barrier.wait()  # Barrier 1 for master process

        num_results = 0
        while num_results != total_prompts:
            num_results = self.result_queue.qsize()
            pbar.update(num_results - pbar.n)
            time.sleep(1)
        pbar.close()

        self.barrier.wait()  # Barrier 2 for master process


    @staticmethod
    def _run_client(barrier, query_queue, result_queue, client_class, client_config, warmup_requests):
        client = client_class(client_config)

        for _ in range(warmup_requests):
            prompt = query_queue.get(timeout=1.0)
            request_kwargs = client.prepare_request(prompt)
            raw_response = client.send_request(request_kwargs)
            response = client.process_response(raw_response)

        barrier.wait()  # Barrier 1 for client process
        try:
            while not query_queue.empty():
                prompt = query_queue.get(timeout=1.0)
                request_kwargs = client.prepare_request(prompt)
                start_time = time.time()
                raw_response = client.send_request(request_kwargs)
                end_time = time.time()
                response = client.process_response(raw_response)
                response.request_time = end_time - start_time
                result_queue.put_nowait(response)
        except queue.Empty:
            pass

        barrier.wait()  # Barrier 2 for client process

    def _benchmark_settings(self):
        prompt_config_keys = list(PromptConfig.model_fields.keys()) + ["num_clients"]

        configs_list = []
        for f in self.config.config_files:
            logger.info(f"Generating benchmark run settings from config file: {f}")
            with open(f, "r") as fh:
                file_config = yaml.safe_load(fh)
                for key in prompt_config_keys:
                    if key not in file_config:
                        file_config[key] = getattr(self.config, key)
                configs_list.append(file_config)

        if not configs_list:
            logger.info("Generating benchmark run settings from command line args")
            configs_list.append({key: getattr(self.config, key) for key in prompt_config_keys})

        all_config_product = []
        for config in configs_list:
            for k, v in config.items():
                if not isinstance(v, (list, tuple)):
                    config[k] = [v]
            for vals in itertools.product(*[config[k] for k in prompt_config_keys]):
                all_config_product.append({k: v for k, v in zip(prompt_config_keys, vals)})

        logger.info(f"Generated {len(all_config_product)} benchmark run setting(s)")

        for config in all_config_product:
            num_clients = config.pop("num_clients")
            prompt_config = PromptConfig(**config)
            yield num_clients, prompt_config

    def _clear_queues(self):
        self.query_queue = self.queue_cls()
        self.result_queue = self.queue_cls()

    def _save_results(self, num_clients, prompt_config):
        response_details = []
        while len(response_details) != num_clients * self.config.num_requests_per_client:
            res = self.result_queue.get()
            # vLLM returns concatenated tokens
            response_details.append(res)
        return response_details

    def run(self):
        self.client_class.start_service(self.client_config)
        for num_clients, prompt_config in self._benchmark_settings():
            logger.info(f"Running benchmark with {num_clients} client(s) and prompt config: {prompt_config}")
            self._clear_queues()
            self._generate_prompts(prompt_config=prompt_config, num_clients=num_clients)
            # self._prepare_requests()
            self._launch_clients(num_clients=num_clients)
            # self._process_responses()
            rd = self._save_results(prompt_config=prompt_config, num_clients=num_clients)
            print(len(rd))
        self.client_class.stop_service(self.client_config)


if __name__ == "__main__":
    from .arg_parsing import parse_args_to_configs

    benchmark_config, client_config = parse_args_to_configs()
    benchmark_runner = BenchmarkRunner(benchmark_config, client_config)
    benchmark_runner.run()
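A hedged usage sketch (not part of the commit): how the runner might be driven programmatically with the dummy backend for a quick smoke test. The field values are illustrative, flag names follow the BenchmarkConfig fields above, and running the module entry point with python -m and the equivalent --flags should behave the same way, assuming the src/ package layout shown in this commit.

    # Hypothetical smoke-test invocation, assuming these modules are importable as the package "src"
    # (e.g. from the llm-bench directory). Values are illustrative, not recommendations.
    from src.benchmark_runner import BenchmarkConfig, BenchmarkRunner
    from src.clients import client_config_classes

    benchmark_config = BenchmarkConfig(
        model="facebook/opt-125m",    # a real HF model name, in case prompt generation loads its tokenizer
        api="dummy",                  # selects DummyClient / DummyClientConfig from the registries
        use_threading=True,           # threads keep a local test simple (no process spawn or pickling)
        num_clients=[2],
        num_requests_per_client=4,
        prompt_length=[128],
        max_new_tokens=[16],
    )
    client_config = client_config_classes[benchmark_config.api]()  # DummyClientConfig has no required fields
    BenchmarkRunner(benchmark_config, client_config).run()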
7 changes: 7 additions & 0 deletions benchmarks/inference/llm-bench/src/clients/__init__.py
@@ -0,0 +1,7 @@
from .azure_ml_client import AzureMLClientConfig, AzureMLClient
from .fastgen_client import FastGenClientConfig, FastGenClient
from .vllm_client import vLLMClientConfig, vLLMClient
from .dummy_client import DummyClientConfig, DummyClient

client_config_classes = {"dummy": DummyClientConfig, "azure_ml": AzureMLClientConfig, "fastgen": FastGenClientConfig, "vllm": vLLMClientConfig}
client_classes = {"dummy": DummyClient, "azure_ml": AzureMLClient, "fastgen": FastGenClient, "vllm": vLLMClient}
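The two dictionaries are keyed by the same api string, so selecting a backend is a pair of lookups; a minimal sketch of the pattern that arg_parsing.py and BenchmarkRunner.__init__ rely on:

    # Lookup sketch (assumes the package is importable as "src", as in the commit layout).
    from src.clients import client_classes, client_config_classes

    api = "dummy"                                                # e.g. benchmark_config.api
    client = client_classes[api](client_config_classes[api]())  # DummyClientConfig takes no required fields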
66 changes: 66 additions & 0 deletions benchmarks/inference/llm-bench/src/clients/azure_ml_client.py
@@ -0,0 +1,66 @@
from .base import BaseClient
from ..config import BaseConfigModel
from ..status import Status
from ..prompt import Prompt
from ..response import Response

import requests
import json
from typing import Any, Dict, Optional


class AzureMLClientConfig(BaseConfigModel):
    api_url: str = ""
    api_key: str = ""
    deployment_name: str = ""

class AzureMLClient(BaseClient):
    def __init__(self, config: AzureMLClientConfig) -> None:
        self.api_url = config.api_url
        self.api_key = config.api_key
        self.deployment_name = config.deployment_name

    @staticmethod
    def start_service(config: AzureMLClientConfig) -> Status:
        pass

    @staticmethod
    def stop_service(config: AzureMLClientConfig) -> Status:
        pass

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        if prompt.streaming:
            raise ValueError("AzureMLClient does not support streaming prompts.")

        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.api_key),
            "azureml-model-deployment": self.deployment_name,
        }
        pload = {
            "input_data": {
                "input_string": [
                    prompt.text,
                ],
                "parameters": {
                    "max_tokens": prompt.max_new_tokens,
                    "return_full_text": prompt.return_full_text,
                },
            }
        }
        return {"url": self.api_url, "headers": headers, "json": pload, "timeout": 180}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        while True:
            try:  # Sometimes the AML endpoint will return an error, so we send the request again
                response = requests.post(**request_kwargs)
                output = json.loads(response.content)
                break
            except Exception as e:
                print(f"Connection failed with {e}. Retrying AML request")

        return output

    def process_response(self, raw_response: Any) -> Response:
        response_text = raw_response[0]
        return Response(response_text)
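For readers unfamiliar with the Azure ML online-endpoint scoring format, this is roughly what prepare_request assembles for a non-streaming prompt; the URL, key, deployment name, and prompt text are placeholders, and nothing is sent here.

    # Illustrative only: the request_kwargs dict that would be passed to requests.post(**request_kwargs).
    import json

    request_kwargs = {
        "url": "https://<endpoint>.<region>.inference.ml.azure.com/score",  # placeholder scoring URL
        "headers": {
            "Content-Type": "application/json",
            "Authorization": "Bearer <api-key>",
            "azureml-model-deployment": "<deployment-name>",
        },
        "json": {
            "input_data": {
                "input_string": ["What is DeepSpeed?"],
                "parameters": {"max_tokens": 60, "return_full_text": False},
            }
        },
        "timeout": 180,
    }
    print(json.dumps(request_kwargs["json"], indent=2))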
35 changes: 35 additions & 0 deletions benchmarks/inference/llm-bench/src/clients/base.py
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod

from ..prompt import Prompt
from ..response import Response
from ..status import Status
from ..config import BaseConfigModel

from typing import Any, Dict

class BaseClient(ABC):
    @abstractmethod
    def __init__(self):
        pass

    @staticmethod
    @abstractmethod
    def start_service(config: BaseConfigModel) -> Status:
        pass

    @staticmethod
    @abstractmethod
    def stop_service(config: BaseConfigModel) -> Status:
        pass

    @abstractmethod
    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        pass

    @abstractmethod
    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        pass

    @abstractmethod
    def process_response(self, raw_response: Any) -> Response:
        pass
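Adding a backend means implementing these five hooks, which _run_client exercises as prepare_request, then send_request, then process_response, with start_service and stop_service wrapping the whole run. As a sketch, a hypothetical EchoClient (names are illustrative and not part of the commit) could look like the following; the DummyClient in the next file is the commit's own minimal example of the same contract.

    # Hypothetical backend skeleton, assuming the same sibling modules as the clients in this commit.
    from typing import Any, Dict

    from ..config import BaseConfigModel
    from ..prompt import Prompt
    from ..response import Response
    from ..status import Status
    from .base import BaseClient

    class EchoClientConfig(BaseConfigModel):
        prefix: str = "echo: "

    class EchoClient(BaseClient):
        def __init__(self, config: EchoClientConfig) -> None:
            self.prefix = config.prefix

        @staticmethod
        def start_service(config: EchoClientConfig) -> Status:
            return Status("OK")  # nothing to launch for a purely local backend

        @staticmethod
        def stop_service(config: EchoClientConfig) -> Status:
            return Status("OK")

        def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
            return {"text": prompt.text}

        def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
            return self.prefix + request_kwargs["text"]

        def process_response(self, raw_response: Any) -> Response:
            return Response(raw_response)

Registering such a backend would presumably just mean adding "echo" entries to client_config_classes and client_classes in clients/__init__.py.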
37 changes: 37 additions & 0 deletions benchmarks/inference/llm-bench/src/clients/dummy_client.py
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod

from ..config import BaseConfigModel
from ..prompt import Prompt
from ..response import Response
from ..status import Status
from .base import BaseClient

from typing import Any, Dict
import time
import random

class DummyClientConfig(BaseConfigModel):
    pass

class DummyClient(BaseClient):
    def __init__(self, config: DummyClientConfig) -> None:
        pass

    @staticmethod
    def start_service(config: DummyClientConfig) -> Status:
        return Status("OK")

    @staticmethod
    def stop_service(config: DummyClientConfig) -> Status:
        return Status("OK")

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        return {"input_text": prompt.text, "max_new_tokens": prompt.max_new_tokens}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        time.sleep(random.uniform(1, 2))
        # time.sleep(1)
        return request_kwargs["input_text"] * 2

    def process_response(self, raw_response: Any) -> Response:
        return Response(raw_response)
