diff --git a/.github/workflows/python-package-genai.yml b/.github/workflows/python-package-genai.yml new file mode 100644 index 000000000..85d763f49 --- /dev/null +++ b/.github/workflows/python-package-genai.yml @@ -0,0 +1,60 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions for GenAI-Perf. +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python package (GenAI-Perf) + +on: + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-22.04"] + python-version: ["3.8", "3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + cd src/c++/perf_analyzer/genai-perf/ + python -m pip install --upgrade pip + python -m pip install -e . 
+ python -c "import genai_perf; print(genai_perf.__version__)" + - name: Test with pytest + run: | + pip install pytest pytest-timeout pytest-cov psutil + cd src/c++/perf_analyzer/genai-perf/tests + pytest --doctest-modules --junitxml=junit/test-results.xml --cov=genai_perf --cov-report=xml --cov-report=html --ignore-glob=test_models diff --git a/src/c++/perf_analyzer/CMakeLists.txt b/src/c++/perf_analyzer/CMakeLists.txt index fe34ace4f..b81795e38 100644 --- a/src/c++/perf_analyzer/CMakeLists.txt +++ b/src/c++/perf_analyzer/CMakeLists.txt @@ -112,6 +112,7 @@ set( profile_data_exporter.h periodic_concurrency_manager.h periodic_concurrency_worker.h + thread_config.h ) add_executable( diff --git a/src/c++/perf_analyzer/client_backend/client_backend.h b/src/c++/perf_analyzer/client_backend/client_backend.h index f3caa7707..06f68c2e3 100644 --- a/src/c++/perf_analyzer/client_backend/client_backend.h +++ b/src/c++/perf_analyzer/client_backend/client_backend.h @@ -138,6 +138,8 @@ enum BackendKind { TRITON_C_API = 3, OPENAI = 4 }; +std::string BackendKindToString(const BackendKind kind); + enum ProtocolType { HTTP = 0, GRPC = 1, UNKNOWN = 2 }; enum GrpcCompressionAlgorithm { COMPRESS_NONE = 0, diff --git a/src/c++/perf_analyzer/command_line_parser.cc b/src/c++/perf_analyzer/command_line_parser.cc index b23295703..bd3d72d73 100644 --- a/src/c++/perf_analyzer/command_line_parser.cc +++ b/src/c++/perf_analyzer/command_line_parser.cc @@ -137,6 +137,7 @@ CLParser::Usage(const std::string& msg) "profiling>" << std::endl; std::cerr << "\t--percentile " << std::endl; + std::cerr << "\t--request-count " << std::endl; std::cerr << "\tDEPRECATED OPTIONS" << std::endl; std::cerr << "\t-t " << std::endl; std::cerr << "\t-c " << std::endl; @@ -463,6 +464,14 @@ CLParser::Usage(const std::string& msg) "that the average latency is used to determine stability", 18) << std::endl; + std::cerr + << FormatMessage( + " --request-count: Specifies a total number of requests to " + "use for measurement. The default is 0, which means that there is " + "no request count and the measurement will proceed using windows " + "until stabilization is detected.", + 18) + << std::endl; std::cerr << FormatMessage( " --serial-sequences: Enables serial sequence mode " "where a maximum of one request is outstanding at a time " @@ -879,6 +888,7 @@ CLParser::ParseCommandLine(int argc, char** argv) {"request-period", required_argument, 0, 59}, {"request-parameter", required_argument, 0, 60}, {"endpoint", required_argument, 0, 61}, + {"request-count", required_argument, 0, 62}, {0, 0, 0, 0}}; // Parse commandline... @@ -1614,6 +1624,13 @@ CLParser::ParseCommandLine(int argc, char** argv) params_->endpoint = optarg; break; } + case 62: { + if (std::stoi(optarg) < 0) { + Usage("Failed to parse --request-count. The value must be > 0."); + } + params_->request_count = std::stoi(optarg); + break; + } case 'v': params_->extra_verbose = params_->verbose; params_->verbose = true; @@ -1705,6 +1722,13 @@ CLParser::ParseCommandLine(int argc, char** argv) // Will be using user-provided time intervals, hence no control variable. 
params_->search_mode = SearchMode::NONE; } + + // When the request-count feature is enabled, override the measurement mode to + // be count windows with a window size of the requested count + if (params_->request_count) { + params_->measurement_mode = MeasurementMode::COUNT_WINDOWS; + params_->measurement_request_count = params_->request_count; + } } void @@ -1874,6 +1898,31 @@ CLParser::VerifyOptions() "binary search mode."); } + if (params_->request_count != 0) { + if (params_->using_concurrency_range) { + if (params_->request_count < params_->concurrency_range.start) { + Usage("request-count can not be less than concurrency"); + } + if (params_->concurrency_range.start < params_->concurrency_range.end) { + Usage( + "request-count not supported with multiple concurrency values in " + "one run"); + } + } + if (params_->using_request_rate_range) { + if (params_->request_count < + static_cast<size_t>(params_->request_rate_range[0])) { + Usage("request-count can not be less than request-rate"); + } + if (params_->request_rate_range[SEARCH_RANGE::kSTART] < + params_->request_rate_range[SEARCH_RANGE::kEND]) { + Usage( + "request-count not supported with multiple request-rate values in " + "one run"); + } + } + } + if (params_->kind == cb::TENSORFLOW_SERVING) { if (params_->protocol != cb::ProtocolType::GRPC) { Usage( diff --git a/src/c++/perf_analyzer/command_line_parser.h b/src/c++/perf_analyzer/command_line_parser.h index cbd807eb4..461e24e2d 100644 --- a/src/c++/perf_analyzer/command_line_parser.h +++ b/src/c++/perf_analyzer/command_line_parser.h @@ -62,6 +62,7 @@ struct PerfAnalyzerParameters { uint64_t latency_threshold_ms = NO_LIMIT; double stability_threshold = 0.1; size_t max_trials = 10; + size_t request_count = 0; bool zero_input = false; size_t string_length = 128; std::string string_data; diff --git a/src/c++/perf_analyzer/concurrency_manager.cc b/src/c++/perf_analyzer/concurrency_manager.cc index a64062cc0..283861846 100644 --- a/src/c++/perf_analyzer/concurrency_manager.cc +++ b/src/c++/perf_analyzer/concurrency_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -84,10 +84,10 @@ ConcurrencyManager::InitManagerFinalize() cb::Error ConcurrencyManager::ChangeConcurrencyLevel( - const size_t concurrent_request_count) + const size_t concurrent_request_count, const size_t request_count) { PauseSequenceWorkers(); - ReconfigThreads(concurrent_request_count); + ReconfigThreads(concurrent_request_count, request_count); ResumeSequenceWorkers(); std::cout << "Request concurrency: " << concurrent_request_count << std::endl; @@ -109,7 +109,8 @@ ConcurrencyManager::PauseSequenceWorkers() } void -ConcurrencyManager::ReconfigThreads(const size_t concurrent_request_count) +ConcurrencyManager::ReconfigThreads( + size_t concurrent_request_count, size_t request_count) { // Always prefer to create new threads if the maximum limit has not been met // @@ -121,8 +122,7 @@ ConcurrencyManager::ReconfigThreads(const size_t concurrent_request_count) (threads_.size() < max_threads_)) { // Launch new thread for inferencing threads_stat_.emplace_back(new ThreadStat()); - threads_config_.emplace_back( - new ConcurrencyWorker::ThreadConfig(threads_config_.size())); + threads_config_.emplace_back(new ThreadConfig(threads_config_.size())); workers_.push_back( MakeWorker(threads_stat_.back(), threads_config_.back())); @@ -138,6 +138,10 @@ ConcurrencyManager::ReconfigThreads(const size_t concurrent_request_count) // and spread the remaining value size_t avg_concurrency = concurrent_request_count / threads_.size(); size_t threads_add_one = concurrent_request_count % threads_.size(); + + size_t avg_req_count = request_count / threads_.size(); + size_t req_count_add_one = request_count % threads_.size(); + size_t seq_stat_index_offset = 0; active_threads_ = 0; for (size_t i = 0; i < threads_stat_.size(); i++) { @@ -145,6 +149,10 @@ ConcurrencyManager::ReconfigThreads(const size_t concurrent_request_count) threads_config_[i]->concurrency_ = concurrency; threads_config_[i]->seq_stat_index_offset_ = seq_stat_index_offset; + + size_t thread_num_reqs = avg_req_count + (i < req_count_add_one ? 1 : 0); + threads_config_[i]->num_requests_ = thread_num_reqs; + seq_stat_index_offset += concurrency; if (concurrency) { @@ -171,7 +179,7 @@ ConcurrencyManager::ResumeSequenceWorkers() std::shared_ptr<IWorker> ConcurrencyManager::MakeWorker( std::shared_ptr<ThreadStat> thread_stat, - std::shared_ptr<ConcurrencyWorker::ThreadConfig> thread_config) + std::shared_ptr<ThreadConfig> thread_config) { uint32_t id = workers_.size(); diff --git a/src/c++/perf_analyzer/concurrency_manager.h b/src/c++/perf_analyzer/concurrency_manager.h index 513d7396c..c6c90f1d1 100644 --- a/src/c++/perf_analyzer/concurrency_manager.h +++ b/src/c++/perf_analyzer/concurrency_manager.h @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -89,14 +89,16 @@ class ConcurrencyManager : public LoadManager { /// Adjusts the number of concurrent requests to be the same as /// 'concurrent_request_count' (by creating or pausing threads) /// \param concurent_request_count The number of concurrent requests. + /// \param request_count The number of requests to generate. If 0, then + /// there is no limit, and it will generate until told to stop. /// \return cb::Error object indicating success or failure.
- cb::Error ChangeConcurrencyLevel(const size_t concurrent_request_count); + cb::Error ChangeConcurrencyLevel( + const size_t concurrent_request_count, const size_t request_count = 0); protected: // Makes a new worker virtual std::shared_ptr<IWorker> MakeWorker( - std::shared_ptr<ThreadStat>, - std::shared_ptr<ConcurrencyWorker::ThreadConfig>); + std::shared_ptr<ThreadStat>, std::shared_ptr<ThreadConfig>); ConcurrencyManager( const bool async, const bool streaming, const int32_t batch_size, @@ -114,7 +116,7 @@ class ConcurrencyManager : public LoadManager { size_t max_concurrency_; - std::vector<std::shared_ptr<ConcurrencyWorker::ThreadConfig>> threads_config_; + std::vector<std::shared_ptr<ThreadConfig>> threads_config_; private: void InitManagerFinalize() override; @@ -126,7 +128,7 @@ class ConcurrencyManager : public LoadManager { // Create new threads (if necessary), and then reconfigure all worker threads // to handle the new concurrent request count // - void ReconfigThreads(size_t concurrent_request_count); + void ReconfigThreads(size_t concurrent_request_count, size_t request_count); // Restart all worker threads that were working on sequences // diff --git a/src/c++/perf_analyzer/concurrency_worker.h b/src/c++/perf_analyzer/concurrency_worker.h index 94cb90fbe..4645f07af 100644 --- a/src/c++/perf_analyzer/concurrency_worker.h +++ b/src/c++/perf_analyzer/concurrency_worker.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -29,6 +29,7 @@ #include "load_worker.h" #include "sequence_manager.h" +#include "thread_config.h" namespace triton { namespace perfanalyzer { @@ -49,28 +50,6 @@ class NaggyMockConcurrencyWorker; /// class ConcurrencyWorker : public LoadWorker { public: - struct ThreadConfig { - ThreadConfig( - size_t thread_id, size_t concurrency = 0, - size_t seq_stat_index_offset = 0) - : thread_id_(thread_id), concurrency_(concurrency), - seq_stat_index_offset_(seq_stat_index_offset), is_paused_(false) - { - } - - // ID of corresponding worker thread - size_t thread_id_; - - // The concurrency level that the worker should produce - size_t concurrency_; - - // The starting sequence stat index for this worker - size_t seq_stat_index_offset_; - - // Whether or not the thread is issuing new inference requests - bool is_paused_; - }; - ConcurrencyWorker( uint32_t id, std::shared_ptr<ThreadStat> thread_stat, std::shared_ptr<ThreadConfig> thread_config, @@ -85,11 +64,11 @@ class ConcurrencyWorker : public LoadWorker { const std::shared_ptr<IInferDataManager>& infer_data_manager, std::shared_ptr<SequenceManager> sequence_manager) : LoadWorker( - id, thread_stat, parser, data_loader, factory, on_sequence_model, - async, streaming, batch_size, using_json_data, wake_signal, - wake_mutex, execute, infer_data_manager, sequence_manager), - thread_config_(thread_config), max_concurrency_(max_concurrency), - active_threads_(active_threads) + id, thread_stat, thread_config, parser, data_loader, factory, + on_sequence_model, async, streaming, batch_size, using_json_data, + wake_signal, wake_mutex, execute, infer_data_manager, + sequence_manager), + max_concurrency_(max_concurrency), active_threads_(active_threads) { } @@ -109,8 +88,6 @@ class ConcurrencyWorker : public LoadWorker { // threads?
size_t& active_threads_; - std::shared_ptr<ThreadConfig> thread_config_; - // Handle the case where execute_ is false void HandleExecuteOff(); diff --git a/src/c++/perf_analyzer/custom_load_manager.cc b/src/c++/perf_analyzer/custom_load_manager.cc index 32e5693b0..55a20a690 100644 --- a/src/c++/perf_analyzer/custom_load_manager.cc +++ b/src/c++/perf_analyzer/custom_load_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -76,10 +76,10 @@ CustomLoadManager::CustomLoadManager( } cb::Error -CustomLoadManager::InitCustomIntervals() +CustomLoadManager::InitCustomIntervals(const size_t request_count) { PauseWorkers(); - ConfigureThreads(); + ConfigureThreads(request_count); auto status = GenerateSchedule(); ResumeWorkers(); return status; } diff --git a/src/c++/perf_analyzer/custom_load_manager.h b/src/c++/perf_analyzer/custom_load_manager.h index c762e9c7e..39c51d99f 100644 --- a/src/c++/perf_analyzer/custom_load_manager.h +++ b/src/c++/perf_analyzer/custom_load_manager.h @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -88,8 +88,10 @@ class CustomLoadManager : public RequestRateManager { /// Initializes the load manager with the provided file containing request /// intervals + /// \param request_count The number of requests to generate. If 0, then + /// there is no limit, and it will generate until told to stop. /// \return cb::Error object indicating success or failure. - cb::Error InitCustomIntervals(); + cb::Error InitCustomIntervals(const size_t request_count); /// Computes the request rate from the time interval file. Fails with an error /// if the file is not present or is empty.
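The --request-count plumbing above reduces to one piece of arithmetic in ConcurrencyManager::ReconfigThreads(): every worker thread receives the integer average of the requested total, and the first request_count % num_threads threads each take one extra request, the same scheme already used to spread concurrency. A minimal Python sketch of that split, not part of the patch (the helper name split_requests is illustrative only):

def split_requests(total_requests: int, num_threads: int) -> list:
    # Integer average per thread; the remainder is handed out one-by-one
    # to the first few threads, mirroring ReconfigThreads().
    avg, remainder = divmod(total_requests, num_threads)
    return [avg + (1 if i < remainder else 0) for i in range(num_threads)]

# e.g. --request-count 10 spread over 4 worker threads -> [3, 3, 2, 2]
print(split_requests(10, 4))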
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/constants.py b/src/c++/perf_analyzer/genai-perf/genai_perf/constants.py index b0d334116..b951524bf 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/constants.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/constants.py @@ -34,4 +34,5 @@ DEFAULT_ARTIFACT_DIR = "artifacts" +DEFAULT_COMPARE_DIR = "compare" DEFAULT_PARQUET_FILE = "all_data" diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py index c17eaa0b2..3137d2fe4 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -39,14 +39,15 @@ class OutputFormat(Enum): TENSORRTLLM = auto() VLLM = auto() + def to_lowercase(self): + return self.name.lower() + class LlmInputs: """ A library of methods that control the generation of LLM Inputs """ - OUTPUT_FILENAME = DEFAULT_INPUT_DATA_JSON - OPEN_ORCA_URL = "https://datasets-server.huggingface.co/rows?dataset=Open-Orca%2FOpenOrca&config=default&split=train" CNN_DAILYMAIL_URL = "https://datasets-server.huggingface.co/rows?dataset=cnn_dailymail&config=1.0.0&split=train" @@ -92,6 +93,7 @@ def create_llm_inputs( add_stream: bool = False, tokenizer: Tokenizer = get_tokenizer(DEFAULT_TOKENIZER), extra_inputs: Optional[Dict] = None, + output_dir: Path = Path(""), ) -> Dict: """ Given an input type, input format, and output type. Output a string of LLM Inputs @@ -193,7 +195,7 @@ def create_llm_inputs( output_tokens_deterministic, model_name, ) - cls._write_json_to_file(json_in_pa_format) + cls._write_json_to_file(json_in_pa_format, output_dir) return json_in_pa_format @@ -540,8 +542,9 @@ def _convert_generic_json_to_trtllm_format( return pa_json @classmethod - def _write_json_to_file(cls, json_in_pa_format: Dict) -> None: - with open(DEFAULT_INPUT_DATA_JSON, "w") as f: + def _write_json_to_file(cls, json_in_pa_format: Dict, output_dir: Path) -> None: + filename = output_dir / DEFAULT_INPUT_DATA_JSON + with open(str(filename), "w") as f: f.write(json.dumps(json_in_pa_format, indent=2)) @classmethod diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_metrics.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_metrics.py index 221d26552..24fcb49f9 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_metrics.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_metrics.py @@ -28,20 +28,23 @@ import csv import json -from itertools import pairwise -from typing import List +from enum import Enum, auto +from itertools import tee +from pathlib import Path +from typing import Dict, List, Tuple, Union import numpy as np import pandas as pd -from genai_perf.constants import DEFAULT_ARTIFACT_DIR -from genai_perf.llm_inputs.llm_inputs import OutputFormat from genai_perf.tokenizer import Tokenizer from genai_perf.utils import load_json, remove_sse_prefix from rich.console import Console from rich.table import Table -_OPENAI_CHAT_COMPLETIONS = OutputFormat.OPENAI_CHAT_COMPLETIONS -_OPENAI_COMPLETIONS = OutputFormat.OPENAI_COMPLETIONS + +class ResponseFormat(Enum): + OPENAI_CHAT_COMPLETIONS = auto() + OPENAI_COMPLETIONS = auto() + TRITON = auto() class Metrics: @@ -112,7 +115,7 @@ def __init__( request_throughputs: List[float] = [], request_latencies: List[int] = [], time_to_first_tokens: List[int] = [], - inter_token_latencies: List[list[int]] = [[]], + inter_token_latencies: List[List[int]] = [[]], output_token_throughputs: 
List[float] = [], output_token_throughputs_per_request: List[int] = [], num_output_tokens: List[int] = [], @@ -167,7 +170,7 @@ def __init__(self, metrics: Metrics): self._calculate_minmax(data, attr) self._calculate_std(data, attr) - def _preprocess_data(self, data: list, attr: str) -> list[int | float]: + def _preprocess_data(self, data: List, attr: str) -> List[Union[int, float]]: new_data = [] if attr == "inter_token_latency": # flatten inter token latencies to 1D @@ -177,11 +180,11 @@ def _preprocess_data(self, data: list, attr: str) -> list[int | float]: new_data = data return new_data - def _calculate_mean(self, data: list[int | float], attr: str) -> None: + def _calculate_mean(self, data: List[Union[int, float]], attr: str) -> None: avg = np.mean(data) setattr(self, "avg_" + attr, avg) - def _calculate_percentiles(self, data: list[int | float], attr: str) -> None: + def _calculate_percentiles(self, data: List[Union[int, float]], attr: str) -> None: p25, p50, p75 = np.percentile(data, [25, 50, 75]) p90, p95, p99 = np.percentile(data, [90, 95, 99]) setattr(self, "p25_" + attr, p25) @@ -191,12 +194,12 @@ def _calculate_percentiles(self, data: list[int | float], attr: str) -> None: setattr(self, "p95_" + attr, p95) setattr(self, "p99_" + attr, p99) - def _calculate_minmax(self, data: list[int | float], attr: str) -> None: + def _calculate_minmax(self, data: List[Union[int, float]], attr: str) -> None: min, max = np.min(data), np.max(data) setattr(self, "min_" + attr, min) setattr(self, "max_" + attr, max) - def _calculate_std(self, data: list[int | float], attr: str) -> None: + def _calculate_std(self, data: List[Union[int, float]], attr: str) -> None: std = np.std(data) setattr(self, "std_" + attr, std) @@ -373,15 +376,17 @@ def export_to_csv(self, csv_filename: str) -> None: for row in singular_metric_rows: csv_writer.writerow(row) - def export_parquet(self, parquet_filename: str) -> None: + def export_parquet(self, artifact_dir: Path, filename: str) -> None: max_length = -1 col_index = 0 filler_list = [] df = pd.DataFrame() + # Data frames require all columns of the same length # find the max length column for key, value in self._metrics.data.items(): max_length = max(max_length, len(value)) + # Insert None for shorter columns to match longest column for key, value in self._metrics.data.items(): if len(value) < max_length: @@ -391,9 +396,9 @@ def export_parquet(self, parquet_filename: str) -> None: diff = 0 filler_list = [] col_index = col_index + 1 - df.to_parquet( - f"{DEFAULT_ARTIFACT_DIR}/data/{parquet_filename}.gzip", compression="gzip" - ) + + filepath = artifact_dir / f"{filename}.gzip" + df.to_parquet(filepath, compression="gzip") class ProfileDataParser: @@ -401,10 +406,36 @@ class ProfileDataParser: extract core metrics and calculate various performance statistics. 
""" - def __init__(self, filename: str) -> None: + def __init__(self, filename: Path) -> None: data = load_json(filename) + self._get_profile_metadata(data) self._parse_profile_data(data) + def _get_profile_metadata(self, data: dict) -> None: + self._service_kind = data["service_kind"] + if self._service_kind == "openai": + if data["endpoint"] == "v1/chat/completions": + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + elif data["endpoint"] == "v1/completions": + self._response_format = ResponseFormat.OPENAI_COMPLETIONS + else: + # TPA-66: add PA metadata to handle this case + # When endpoint field is either empty or custom endpoint, fall + # back to parsing the response to extract the response format. + request = data["experiments"][0]["requests"][0] + response = request["response_outputs"][0]["response"] + if "chat.completion" in response: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + elif "text_completion" in response: + self._response_format = ResponseFormat.OPENAI_COMPLETIONS + else: + raise RuntimeError("Unknown OpenAI response format.") + + elif self._service_kind == "triton": + self._response_format = ResponseFormat.TRITON + else: + raise ValueError(f"Unknown service kind: {self._service_kind}") + def _parse_profile_data(self, data: dict) -> None: """Parse through the entire profile data to collect statistics.""" self._profile_results = {} @@ -429,6 +460,10 @@ def get_statistics(self, infer_mode: str, load_level: str) -> Statistics: raise KeyError(f"Profile with {infer_mode}={load_level} does not exist.") return self._profile_results[(infer_mode, load_level)] + def get_profile_load_info(self) -> List[Tuple[str, str]]: + """Return available (infer_mode, load_level) tuple keys.""" + return [k for k, _ in self._profile_results.items()] + class LLMProfileDataParser(ProfileDataParser): """A class that calculates and aggregates all the LLM performance statistics @@ -447,7 +482,6 @@ class LLMProfileDataParser(ProfileDataParser): >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> pd = LLMProfileDataParser( >>> filename="profile_export.json", - >>> service_kind="triton", >>> tokenizer=tokenizer, >>> ) >>> stats = pd.get_statistics(infer_mode="concurrency", level=10) @@ -458,14 +492,10 @@ class LLMProfileDataParser(ProfileDataParser): def __init__( self, - filename: str, - service_kind: str, - output_format: OutputFormat, + filename: Path, tokenizer: Tokenizer, ) -> None: self._tokenizer = tokenizer - self._service_kind = service_kind - self._output_format = output_format super().__init__(filename) def _parse_requests(self, requests: dict) -> LLMMetrics: @@ -517,7 +547,9 @@ def _parse_requests(self, requests: dict) -> LLMMetrics: # inter token latency itl_per_request = [] - for (t1, _), (t2, n2) in pairwise(zip(res_timestamps, num_output_tokens)): + for (t1, _), (t2, n2) in self._pairwise( + zip(res_timestamps, num_output_tokens) + ): # TMA-1676: handle empty first/last responses # if the latter response has zero token (e.g. 
empty string), # then set it default to one for the sake of inter token latency @@ -542,8 +574,14 @@ def _parse_requests(self, requests: dict) -> LLMMetrics: num_input_tokens, ) + def _pairwise(self, iterable): + """Generate pairs of consecutive elements from the given iterable.""" + a, b = tee(iterable) + next(b, None) + return zip(a, b) + def _preprocess_response( - self, res_timestamps: list[int], res_outputs: list[dict[str, str]] + self, res_timestamps: List[int], res_outputs: List[Dict[str, str]] ) -> None: """Helper function to preprocess responses of a request.""" if self._service_kind == "openai": @@ -562,19 +600,19 @@ def _preprocess_response( # Remove responses without any content # These are only observed to happen at the start or end - while res_outputs[0] and self._is_openai_empty_response( + while res_outputs and self._is_openai_empty_response( res_outputs[0]["response"] ): res_timestamps.pop(0) res_outputs.pop(0) - while res_outputs[-1] and self._is_openai_empty_response( + while res_outputs and self._is_openai_empty_response( res_outputs[-1]["response"] ): res_timestamps.pop() res_outputs.pop() - def _tokenize_request_inputs(self, req_inputs: dict) -> list[int]: + def _tokenize_request_inputs(self, req_inputs: dict) -> List[int]: """Deserialize the request input and return tokenized inputs.""" if self._service_kind == "triton": return self._tokenize_triton_request_input(req_inputs) @@ -583,17 +621,17 @@ def _tokenize_request_inputs(self, req_inputs: dict) -> list[int]: else: raise ValueError(f"Unknown service kind: '{self._service_kind}'.") - def _tokenize_triton_request_input(self, req_inputs: dict) -> list[int]: + def _tokenize_triton_request_input(self, req_inputs: dict) -> List[int]: """Tokenize the Triton request input texts.""" encodings = self._tokenizer(req_inputs["text_input"]) return encodings.data["input_ids"] - def _tokenize_openai_request_input(self, req_inputs: dict) -> list[int]: + def _tokenize_openai_request_input(self, req_inputs: dict) -> List[int]: """Tokenize the OpenAI request input texts.""" payload = json.loads(req_inputs["payload"]) - if self._output_format == _OPENAI_CHAT_COMPLETIONS: + if self._response_format == ResponseFormat.OPENAI_CHAT_COMPLETIONS: input_text = payload["messages"][0]["content"] - elif self._output_format == _OPENAI_COMPLETIONS: + elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS: input_text = payload["prompt"] else: raise ValueError( @@ -602,7 +640,7 @@ def _tokenize_openai_request_input(self, req_inputs: dict) -> list[int]: encodings = self._tokenizer(input_text) return encodings.data["input_ids"] - def _tokenize_response_outputs(self, res_outputs: dict) -> list[list[int]]: + def _tokenize_response_outputs(self, res_outputs: dict) -> List[List[int]]: """Deserialize the response output and return tokenized outputs.""" if self._service_kind == "triton": return self._tokenize_triton_response_output(res_outputs) @@ -611,14 +649,14 @@ def _tokenize_response_outputs(self, res_outputs: dict) -> list[list[int]]: else: raise ValueError(f"Unknown service kind: '{self._service_kind}'.") - def _tokenize_triton_response_output(self, res_outputs: dict) -> list[list[int]]: + def _tokenize_triton_response_output(self, res_outputs: dict) -> List[List[int]]: """Tokenize the Triton response output texts.""" output_texts = [] for output in res_outputs: output_texts.append(output["text_output"]) return self._run_tokenizer(output_texts) - def _tokenize_openai_response_output(self, res_outputs: dict) -> list[list[int]]: + def 
_tokenize_openai_response_output(self, res_outputs: dict) -> List[List[int]]: """Tokenize the OpenAI response output texts.""" output_texts = [] for output in res_outputs: @@ -626,7 +664,7 @@ def _tokenize_openai_response_output(self, res_outputs: dict) -> list[list[int]] output_texts.append(text) return self._run_tokenizer(output_texts) - def _run_tokenizer(self, output_texts: list[str]) -> list[list[int]]: + def _run_tokenizer(self, output_texts: List[str]) -> List[List[int]]: # exclamation mark trick forces the llama tokenization to consistently # start each output with a specific token which allows us to safely skip # the first token of every tokenized output and get only the ones that diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/logging.py b/src/c++/perf_analyzer/genai-perf/genai_perf/logging.py index 9d8914024..db23dff06 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/logging.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/logging.py @@ -70,6 +70,16 @@ def init_logging() -> None: "level": "DEBUG", "propagate": False, }, + "genai_perf.plots.plot_config_parser": { + "handlers": ["console"], + "level": "DEBUG", + "propagate": False, + }, + "genai_perf.plots.plot_manager": { + "handlers": ["console"], + "level": "DEBUG", + "propagate": False, + }, }, } logging.config.dictConfig(LOGGING_CONFIG) diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py index dbb1e295a..04dcc799e 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py @@ -26,7 +26,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os -import shutil import sys import traceback from argparse import Namespace @@ -34,23 +33,20 @@ import genai_perf.logging as logging from genai_perf import parser -from genai_perf.constants import DEFAULT_ARTIFACT_DIR, DEFAULT_PARQUET_FILE +from genai_perf.constants import DEFAULT_PARQUET_FILE from genai_perf.exceptions import GenAIPerfException from genai_perf.llm_inputs.llm_inputs import LlmInputs -from genai_perf.llm_metrics import LLMProfileDataParser, Statistics +from genai_perf.llm_metrics import LLMProfileDataParser +from genai_perf.plots.plot_config_parser import PlotConfigParser from genai_perf.plots.plot_manager import PlotManager from genai_perf.tokenizer import Tokenizer, get_tokenizer -def init_logging() -> None: - logging.init_logging() - - -def create_artifacts_dirs(generate_plots: bool) -> None: - if not os.path.exists(f"{DEFAULT_ARTIFACT_DIR}"): - os.mkdir(f"{DEFAULT_ARTIFACT_DIR}") - os.mkdir(f"{DEFAULT_ARTIFACT_DIR}/data") - os.mkdir(f"{DEFAULT_ARTIFACT_DIR}/plots") +def create_artifacts_dirs(args: Namespace) -> None: + # TMA-1911: support plots CLI option + plot_dir = args.artifact_dir / "plots" + os.makedirs(args.artifact_dir, exist_ok=True) + os.makedirs(plot_dir, exist_ok=True) def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None: @@ -81,25 +77,24 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None: add_stream=args.streaming, tokenizer=tokenizer, extra_inputs=extra_input_dict, + output_dir=args.artifact_dir, ) def calculate_metrics(args: Namespace, tokenizer: Tokenizer) -> LLMProfileDataParser: return LLMProfileDataParser( filename=args.profile_export_file, - service_kind=args.service_kind, - output_format=args.output_format, tokenizer=tokenizer, ) def report_output(data_parser: LLMProfileDataParser, args: Namespace) -> None: - if "concurrency_range" in args: 
+ if args.concurrency: infer_mode = "concurrency" - load_level = args.concurrency_range - elif "request_rate_range" in args: + load_level = f"{args.concurrency}" + elif args.request_rate: infer_mode = "request_rate" - load_level = args.request_rate_range + load_level = f"{args.request_rate}" else: raise GenAIPerfException("No valid infer mode specified") @@ -108,48 +103,48 @@ def report_output(data_parser: LLMProfileDataParser, args: Namespace) -> None: args.profile_export_file.stem + "_genai_perf.csv" ) stats.export_to_csv(export_csv_name) - stats.export_parquet(DEFAULT_PARQUET_FILE) + stats.export_parquet(args.artifact_dir, DEFAULT_PARQUET_FILE) stats.pretty_print() if args.generate_plots: - create_plots(stats) + create_plots(args) -def create_plots(stats: Statistics) -> None: - plot_manager = PlotManager(stats) - plot_manager.create_default_plots() - - -def finalize(profile_export_file: Path): - shutil.move("llm_inputs.json", f"{DEFAULT_ARTIFACT_DIR}/data/llm_inputs.json") - shutil.move( - profile_export_file, f"{DEFAULT_ARTIFACT_DIR}/data/{profile_export_file}" - ) - profile_export_file_csv = profile_export_file.stem + "_genai_perf.csv" - shutil.move( - profile_export_file_csv, - f"{DEFAULT_ARTIFACT_DIR}/data/{profile_export_file_csv}", +def create_plots(args: Namespace) -> None: + # TMA-1911: support plots CLI option + plot_dir = args.artifact_dir / "plots" + PlotConfigParser.create_init_yaml_config( + filenames=[args.profile_export_file], # single run + output_dir=plot_dir, ) + config_parser = PlotConfigParser(plot_dir / "config.yaml") + plot_configs = config_parser.generate_configs() + plot_manager = PlotManager(plot_configs) + plot_manager.generate_plots() # Separate function that can raise exceptions used for testing # to assert correct errors and messages. def run(): try: - init_logging() + # TMA-1900: refactor CLI handler + logging.init_logging() args, extra_args = parser.parse_args() - create_artifacts_dirs(args.generate_plots) - tokenizer = get_tokenizer(args.tokenizer) - generate_inputs(args, tokenizer) - args.func(args, extra_args) - data_parser = calculate_metrics(args, tokenizer) - report_output(data_parser, args) - finalize(args.profile_export_file) + if args.subcommand == "compare": + args.func(args) + else: + create_artifacts_dirs(args) + tokenizer = get_tokenizer(args.tokenizer) + generate_inputs(args, tokenizer) + args.func(args, extra_args) + data_parser = calculate_metrics(args, tokenizer) + report_output(data_parser, args) except Exception as e: raise GenAIPerfException(e) def main(): - # Interactive use will catch exceptions and log formatted errors rather than tracebacks. + # Interactive use will catch exceptions and log formatted errors rather than + # tracebacks. try: run() except Exception as e: diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py index e0b9f404c..4bdfe3c56 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py @@ -25,13 +25,21 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import argparse +import os import sys from pathlib import Path import genai_perf.logging as logging import genai_perf.utils as utils -from genai_perf.constants import CNN_DAILY_MAIL, OPEN_ORCA +from genai_perf.constants import ( + CNN_DAILY_MAIL, + DEFAULT_ARTIFACT_DIR, + DEFAULT_COMPARE_DIR, + OPEN_ORCA, +) from genai_perf.llm_inputs.llm_inputs import LlmInputs, OutputFormat, PromptSource +from genai_perf.plots.plot_config_parser import PlotConfigParser +from genai_perf.plots.plot_manager import PlotManager from genai_perf.tokenizer import DEFAULT_TOKENIZER from . import __version__ @@ -41,6 +49,29 @@ _endpoint_type_map = {"chat": "v1/chat/completions", "completions": "v1/completions"} +def _check_model_args( + parser: argparse.ArgumentParser, args: argparse.Namespace +) -> argparse.Namespace: + """ + Check if model name is provided. + """ + if not args.subcommand and not args.model: + parser.error("The -m/--model option is required and cannot be empty.") + return args + + +def _check_compare_args( + parser: argparse.ArgumentParser, args: argparse.Namespace +) -> argparse.Namespace: + """ + Check compare subcommand args + """ + if args.subcommand == "compare": + if not args.config and not args.files: + parser.error("Either the --config or --files option must be specified.") + return args + + def _check_conditional_args( parser: argparse.ArgumentParser, args: argparse.Namespace ) -> argparse.Namespace: @@ -93,19 +124,52 @@ def _check_conditional_args( return args -def _update_load_manager_args(args: argparse.Namespace) -> argparse.Namespace: +def _check_load_manager_args(args: argparse.Namespace) -> argparse.Namespace: """ - Update genai-perf load manager attributes to PA format + Check inference load args """ - for attr_key in ["concurrency", "request_rate"]: - attr_val = getattr(args, attr_key) - if attr_val is not None: - setattr(args, f"{attr_key}_range", f"{attr_val}") - delattr(args, attr_key) - return args - # If no concurrency or request rate is set, default to 1 - setattr(args, "concurrency_range", "1") + if not args.concurrency and not args.request_rate: + args.concurrency = 1 + return args + + +def _set_artifact_paths(args: argparse.Namespace) -> argparse.Namespace: + """ + Set paths for all the artifacts. + """ + if args.artifact_dir == Path(DEFAULT_ARTIFACT_DIR): + # Preprocess Huggingface model names that include '/' in their model name. + if (args.model is not None) and ("/" in args.model): + filtered_name = "_".join(args.model.split("/")) + logger.info( + f"Model name '{args.model}' cannot be used to create artifact " + f"directory. Instead, '{filtered_name}' will be used." + ) + name = [f"{filtered_name}"] + else: + name = [f"{args.model}"] + + if args.service_kind == "openai": + name += [f"{args.service_kind}-{args.endpoint_type}"] + elif args.service_kind == "triton": + name += [f"{args.service_kind}-{args.backend.to_lowercase()}"] + else: + raise ValueError(f"Unknown service kind '{args.service_kind}'.") + + if args.concurrency: + name += [f"concurrency{args.concurrency}"] + elif args.request_rate: + name += [f"request_rate{args.request_rate}"] + args.artifact_dir = args.artifact_dir / Path("-".join(name)) + + if args.profile_export_file.parent != Path(""): + raise ValueError( + "Please use --artifact-dir option to define intermediary paths to " + "the profile export file." 
+ ) + + args.profile_export_file = args.artifact_dir / args.profile_export_file return args @@ -132,15 +196,6 @@ def _convert_str_to_enum_entry(args, option, enum): return args -### Handlers ### - - -def handler(args, extra_args): - from genai_perf.wrapper import Profiler - - Profiler.run(args=args, extra_args=extra_args) - - ### Parsers ### @@ -286,7 +341,7 @@ def _add_endpoint_args(parser): "-m", "--model", type=str, - required=True, + default=None, help=f"The name of the model to benchmark.", ) @@ -350,24 +405,29 @@ def _add_endpoint_args(parser): def _add_output_args(parser): output_group = parser.add_argument_group("Output") - output_group.add_argument( "--generate-plots", action="store_true", required=False, help="An option to enable the generation of plots.", ) - output_group.add_argument( "--profile-export-file", type=Path, - default="profile_export.json", + default=Path("profile_export.json"), help="The path where the perf_analyzer profile export will be " "generated. By default, the profile export will be to profile_export.json. " "The genai-perf file will be exported to _genai_perf.csv. " "For example, if the profile export file is profile_export.json, the genai-perf file will be " "exported to profile_export_genai_perf.csv.", ) + output_group.add_argument( + "--artifact-dir", + type=Path, + default=Path(DEFAULT_ARTIFACT_DIR), + help="The directory to store all the (output) artifacts generated by " + "GenAI-Perf and Perf Analyzer.", + ) def _add_other_args(parser): @@ -437,6 +497,61 @@ def get_extra_inputs_as_dict(args: argparse.Namespace) -> dict: return request_inputs +def _parse_compare_args(subparsers) -> argparse.ArgumentParser: + compare = subparsers.add_parser( + "compare", + description="Subcommand to generate plots that compare multiple profile runs.", + ) + compare_group = compare.add_argument_group("Compare") + mx_group = compare_group.add_mutually_exclusive_group(required=False) + mx_group.add_argument( + "--config", + type=Path, + default=None, + help="The path to the YAML file that specifies plot configurations for " + "comparing multiple runs.", + ) + mx_group.add_argument( + "-f", + "--files", + nargs="+", + default=[], + help="List of paths to the profile export JSON files. 
Users can specify " + "this option instead of the `--config` option if they would like " + "GenAI-Perf to generate default plots as well as initial YAML config file.", + ) + compare.set_defaults(func=compare_handler) + return compare + + +### Handlers ### + + +def create_compare_dir() -> None: + if not os.path.exists(DEFAULT_COMPARE_DIR): + os.mkdir(DEFAULT_COMPARE_DIR) + + +def profile_handler(args, extra_args): + from genai_perf.wrapper import Profiler + + Profiler.run(args=args, extra_args=extra_args) + + +def compare_handler(args: argparse.Namespace): + """Handles `compare` subcommand workflow.""" + if args.files: + create_compare_dir() + output_dir = Path(f"{DEFAULT_COMPARE_DIR}") + PlotConfigParser.create_init_yaml_config(args.files, output_dir) + args.config = output_dir / "config.yaml" + + config_parser = PlotConfigParser(args.config) + plot_configs = config_parser.generate_configs() + plot_manager = PlotManager(plot_configs) + plot_manager.generate_plots() + + ### Entrypoint ### @@ -448,7 +563,7 @@ def parse_args(): description="CLI to profile LLMs and Generative AI models with Perf Analyzer", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.set_defaults(func=handler) + parser.set_defaults(func=profile_handler) # Conceptually group args for easier visualization _add_endpoint_args(parser) @@ -457,6 +572,12 @@ def parse_args(): _add_output_args(parser) _add_other_args(parser) + # Add subcommands + subparsers = parser.add_subparsers( + help="List of subparser commands.", dest="subcommand" + ) + compare_parser = _parse_compare_args(subparsers) + # Check for passthrough args if "--" in argv: passthrough_index = argv.index("--") @@ -466,7 +587,10 @@ def parse_args(): args = parser.parse_args(argv[1:passthrough_index]) args = _infer_prompt_source(args) + args = _check_model_args(parser, args) args = _check_conditional_args(parser, args) - args = _update_load_manager_args(args) + args = _check_compare_args(compare_parser, args) + args = _check_load_manager_args(args) + args = _set_artifact_paths(args) return args, argv[passthrough_index + 1 :] diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/__init__.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/__init__.py new file mode 100755 index 000000000..086616e41 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/__init__.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/base_plot.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/base_plot.py index a403e57f8..470e0b942 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/base_plot.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/base_plot.py @@ -25,14 +25,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from pathlib import Path +from typing import List -from copy import deepcopy -from typing import Dict - -from genai_perf.constants import DEFAULT_ARTIFACT_DIR +import pandas as pd from genai_perf.exceptions import GenAIPerfException -from genai_perf.llm_metrics import Statistics -from pandas import DataFrame +from genai_perf.plots.plot_config import ProfileRunData from plotly.graph_objects import Figure @@ -41,40 +39,44 @@ class BasePlot: Base class for plots """ - def __init__(self, stats: Statistics, extra_data: Dict | None = None) -> None: - self._stats = stats - self._metrics_data = deepcopy(stats.metrics.data) - if extra_data: - self._metrics_data = self._metrics_data | extra_data + def __init__(self, data: List[ProfileRunData]) -> None: + self._profile_data = data def create_plot( self, - x_key: str, - y_key: str, - x_metric: str, - y_metric: str, graph_title: str, x_label: str, y_label: str, + width: int, + height: int, filename_root: str, + output_dir: Path, ) -> None: """ Create plot for specific graph type """ raise NotImplementedError - def _generate_parquet(self, dataframe: DataFrame, file: str) -> None: - dataframe.to_parquet( - f"{DEFAULT_ARTIFACT_DIR}/data/{file}.gzip", compression="gzip" + def _create_dataframe(self, x_label: str, y_label: str) -> pd.DataFrame: + return pd.DataFrame( + { + x_label: [prd.x_metric for prd in self._profile_data], + y_label: [prd.y_metric for prd in self._profile_data], + "Run Name": [prd.name for prd in self._profile_data], + } ) - def _generate_graph_file(self, fig: Figure, file: str, title: str) -> None: + def _generate_parquet(self, df: pd.DataFrame, output_dir: Path, file: str) -> None: + filepath = output_dir / f"{file}.gzip" + df.to_parquet(filepath, compression="gzip") + + def _generate_graph_file(self, fig: Figure, output_dir: Path, file: str) -> None: if file.endswith("jpeg"): - print(f"Generating '{title}' jpeg") - fig.write_image(f"{DEFAULT_ARTIFACT_DIR}/plots/{file}") + filepath = output_dir / f"{file}" + fig.write_image(filepath) elif file.endswith("html"): - print(f"Generating '{title}' html") - fig.write_html(f"{DEFAULT_ARTIFACT_DIR}/plots/{file}") + filepath = output_dir / f"{file}" + fig.write_html(filepath) else: extension = file.split(".")[-1] raise GenAIPerfException(f"image file type {extension} is not supported") diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/box_plot.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/box_plot.py index 5ba7d80b3..38aad36dc 100755 --- 
a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/box_plot.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/box_plot.py @@ -25,16 +25,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from pathlib import Path +from typing import List -import copy -from typing import Dict - -import pandas as pd -import plotly.express as px -from genai_perf.llm_metrics import Statistics +import plotly.graph_objects as go from genai_perf.plots.base_plot import BasePlot -from genai_perf.utils import scale -from plotly.graph_objects import Figure +from genai_perf.plots.plot_config import ProfileRunData class BoxPlot(BasePlot): @@ -42,88 +38,40 @@ class BoxPlot(BasePlot): Generate a box plot in jpeg and html format. """ - def __init__(self, stats: Statistics, extra_data: Dict | None = None) -> None: - super().__init__(stats, extra_data) + def __init__(self, data: List[ProfileRunData]) -> None: + super().__init__(data) def create_plot( self, - x_key: str = "", - y_key: str = "", - x_metric: str = "", - y_metric: str = "", graph_title: str = "", x_label: str = "", y_label: str = "", + width: int = 700, + height: int = 450, filename_root: str = "", + output_dir: Path = Path(""), ) -> None: - df = pd.DataFrame({y_metric: self._metrics_data[y_key]}) - fig = px.box( - df, - y=y_metric, - points="all", - title=graph_title, + fig = go.Figure() + for pd in self._profile_data: + fig.add_trace(go.Box(y=pd.y_metric, name=pd.name)) + + # Update layout and axis labels + fig.update_layout( + title={ + "text": f"{graph_title}", + "xanchor": "center", + "x": 0.5, + }, + width=width, + height=height, ) - fig.update_layout(title_x=0.5) - fig.update_xaxes(title_text=x_label) - - fig.update_yaxes(title_text="") - - # create a copy to avoid annotations on html file - fig_jpeg = copy.deepcopy(fig) - self._add_annotations(fig_jpeg, y_metric) - - self._generate_parquet(df, filename_root) - self._generate_graph_file(fig, filename_root + ".html", graph_title) - self._generate_graph_file(fig_jpeg, filename_root + ".jpeg", graph_title) + fig.update_traces(boxpoints="all") + fig.update_xaxes(title_text=x_label, showticklabels=False) + fig.update_yaxes(title_text=y_label) - def _add_annotations(self, fig: Figure, y_metric: str) -> None: - """ - Add annotations to the non html version of the box plot - to replace the missing hovertext - """ - stat_root_name = self._stats.metrics.get_base_name(y_metric) + # Save dataframe as parquet file + df = self._create_dataframe(x_label, y_label) + self._generate_parquet(df, output_dir, filename_root) - val = scale(self._stats.data[f"max_{stat_root_name}"], (1 / 1e9)) - fig.add_annotation( - x=0.5, - y=val, - text=f"max: {round(val, 2)}", - showarrow=False, - yshift=10, - ) - - val = scale(self._stats.data[f"p75_{stat_root_name}"], (1 / 1e9)) - fig.add_annotation( - x=0.5, - y=val, - text=f"q3: {round(val, 2)}", - showarrow=False, - yshift=10, - ) - - val = scale(self._stats.data[f"p50_{stat_root_name}"], (1 / 1e9)) - fig.add_annotation( - x=0.5, - y=val, - text=f"median: {round(val, 2)}", - showarrow=False, - yshift=10, - ) - - val = scale(self._stats.data[f"p25_{stat_root_name}"], (1 / 1e9)) - fig.add_annotation( - x=0.5, - y=val, - text=f"q1: {round(val, 2)}", - showarrow=False, - yshift=10, - ) - - val = scale(self._stats.data[f"min_{stat_root_name}"], (1 / 1e9)) - fig.add_annotation( - x=0.5, - y=val, - text=f"min: {round(val, 2)}", - showarrow=False, - yshift=10, - ) + 
self._generate_graph_file(fig, output_dir, filename_root + ".html") + self._generate_graph_file(fig, output_dir, filename_root + ".jpeg") diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/heat_map.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/heat_map.py index ee5571ded..7f4dbe166 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/heat_map.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/heat_map.py @@ -25,12 +25,13 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Dict, Optional +from pathlib import Path +from typing import List -import pandas as pd -import plotly.express as px -from genai_perf.llm_metrics import Statistics +import plotly.graph_objects as go from genai_perf.plots.base_plot import BasePlot +from genai_perf.plots.plot_config import ProfileRunData +from plotly.subplots import make_subplots class HeatMap(BasePlot): @@ -38,43 +39,62 @@ class HeatMap(BasePlot): Generate a heat map in jpeg and html format. """ - def __init__(self, stats: Statistics, extra_data: Optional[Dict] = None) -> None: - super().__init__(stats, extra_data) + def __init__(self, data: List[ProfileRunData]) -> None: + super().__init__(data) def create_plot( self, - x_key: str = "", - y_key: str = "", - x_metric: str = "", - y_metric: str = "", graph_title: str = "", x_label: str = "", y_label: str = "", + width: int = 700, + height: int = 450, filename_root: str = "", + output_dir: Path = Path(""), ) -> None: - x_values = self._metrics_data[x_key] - y_values = self._metrics_data[y_key] - df = pd.DataFrame( - { - x_metric: x_values, - y_metric: y_values, - } - ) - fig = px.density_heatmap( - df, - x=x_metric, - y=y_metric, + N = len(self._profile_data) + + if N <= 3: + n_rows, n_cols = 1, N + else: + n_rows = (N + 2) // 3 + n_cols = 3 + + fig = make_subplots( + rows=n_rows, + cols=n_cols, + x_title=x_label, + y_title=y_label, + subplot_titles=[prd.name for prd in self._profile_data], ) + + for index, prd in enumerate(self._profile_data): + hm = go.Histogram2d( + x=prd.x_metric, + y=prd.y_metric, + coloraxis="coloraxis", + name=prd.name, + ) + + # Calculate the location where the figure should be added in the subplot + c_row = int(index / n_cols) + 1 + c_col = index % n_cols + 1 + fig.add_trace(hm, c_row, c_col) + fig.update_layout( title={ "text": graph_title, "xanchor": "center", "x": 0.5, - } + }, + width=width, + height=height, ) - fig.update_xaxes(title_text=x_label) - fig.update_yaxes(title_text=y_label) - self._generate_parquet(df, filename_root) - self._generate_graph_file(fig, filename_root + ".html", graph_title) - self._generate_graph_file(fig, filename_root + ".jpeg", graph_title) + # Save dataframe as parquet file + df = self._create_dataframe(x_label, y_label) + self._generate_parquet(df, output_dir, filename_root) + + # self._generate_parquet(df, filename_root) + self._generate_graph_file(fig, output_dir, filename_root + ".html") + self._generate_graph_file(fig, output_dir, filename_root + ".jpeg") diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config.py new file mode 100755 index 000000000..2408d0591 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import List, Sequence, Union + + +class PlotType(Enum): + SCATTER = auto() + BOX = auto() + HEATMAP = auto() + + +@dataclass +class ProfileRunData: + name: str + x_metric: Sequence[Union[int, float]] + y_metric: Sequence[Union[int, float]] + + +@dataclass +class PlotConfig: + title: str + data: List[ProfileRunData] + x_label: str + y_label: str + width: int + height: int + type: PlotType + output: Path diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config_parser.py new file mode 100755 index 000000000..c174024a2 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_config_parser.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from pathlib import Path +from typing import List, Union + +import genai_perf.logging as logging + +# Skip type checking to avoid mypy error +# Issue: https://github.com/python/mypy/issues/10632 +import yaml # type: ignore +from genai_perf.llm_metrics import LLMProfileDataParser, Statistics +from genai_perf.plots.plot_config import PlotConfig, PlotType, ProfileRunData +from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer +from genai_perf.utils import load_yaml, scale + +logger = logging.getLogger(__name__) + + +class PlotConfigParser: + """Parses YAML configuration file to generate PlotConfigs.""" + + def __init__(self, filename: Path) -> None: + self._filename = filename + + def generate_configs(self) -> List[PlotConfig]: + """Load YAML configuration file and convert to PlotConfigs.""" + logger.info( + f"Generating plot configurations by parsing {self._filename}. " + "This may take a few seconds.", + ) + configs = load_yaml(self._filename) + + plot_configs = [] + for _, config in configs.items(): + # Collect profile run data + profile_data: List[ProfileRunData] = [] + for filepath in config["paths"]: + stats = self._get_statistics(filepath) + profile_data.append( + ProfileRunData( + name=self._get_run_name(Path(filepath)), + x_metric=self._get_metric(stats, config["x_metric"]), + y_metric=self._get_metric(stats, config["y_metric"]), + ) + ) + + plot_configs.append( + PlotConfig( + title=config["title"], + data=profile_data, + x_label=config["x_label"], + y_label=config["y_label"], + width=config["width"], + height=config["height"], + type=self._get_plot_type(config["type"]), + output=Path(config["output"]), + ) + ) + + return plot_configs + + def _get_statistics(self, filepath: str) -> Statistics: + """Extract a single profile run data.""" + data_parser = LLMProfileDataParser( + filename=Path(filepath), + tokenizer=get_tokenizer(DEFAULT_TOKENIZER), + ) + load_info = data_parser.get_profile_load_info() + + # TMA-1904: Remove single experiment assumption + assert len(load_info) == 1 + infer_mode, load_level = load_info[0] + stats = data_parser.get_statistics(infer_mode, load_level) + return stats + + def _get_run_name(self, filepath: Path) -> str: + """Construct a profile run name.""" + if filepath.parent.name: + return filepath.parent.name + "/" + filepath.stem + return filepath.stem + + def _get_metric(self, stats: Statistics, name: str) -> List[Union[int, float]]: + if not name: # no metric + return [] + elif name == "inter_token_latencies": + # Flatten ITL since they are grouped by request + itl_flatten = [] + for request_itls in stats.metrics.data[name]: + itl_flatten += request_itls + return [scale(x, (1 / 1e6)) for x in itl_flatten] # ns to ms + elif name == "token_positions": + token_positions: List[Union[int, float]] = [] + for request_itls in stats.metrics.data["inter_token_latencies"]: + token_positions += list(range(1, len(request_itls) + 1)) + return token_positions + elif name == "time_to_first_tokens": + ttfts = 
stats.metrics.data[name] + return [scale(x, (1 / 1e6)) for x in ttfts] # ns to ms + elif name == "request_latencies": + req_latencies = stats.metrics.data[name] + return [scale(x, (1 / 1e6)) for x in req_latencies] # ns to ms + + return stats.metrics.data[name] + + def _get_plot_type(self, plot_type: str) -> PlotType: + """Returns the plot type as PlotType object.""" + if plot_type == "scatter": + return PlotType.SCATTER + elif plot_type == "box": + return PlotType.BOX + elif plot_type == "heatmap": + return PlotType.HEATMAP + else: + raise ValueError( + "Unknown plot type encountered while parsing YAML configuration. " + "Plot type must be either 'scatter', 'box', or 'heatmap'." + ) + + @staticmethod + def create_init_yaml_config(filenames: List[Path], output_dir: Path) -> None: + config_str = f""" + plot1: + title: Time to First Token + x_metric: "" + y_metric: time_to_first_tokens + x_label: Time to First Token (ms) + y_label: "" + width: {1200 if len(filenames) > 1 else 700} + height: 450 + type: box + paths: {[str(f) for f in filenames]} + output: {output_dir} + + plot2: + title: Request Latency + x_metric: "" + y_metric: request_latencies + x_label: Request Latency (ms) + y_label: "" + width: {1200 if len(filenames) > 1 else 700} + height: 450 + type: box + paths: {[str(f) for f in filenames]} + output: {output_dir} + + plot3: + title: Distribution of Input Tokens to Generated Tokens + x_metric: num_input_tokens + y_metric: num_output_tokens + x_label: Number of Input Tokens Per Request + y_label: Number of Generated Tokens Per Request + width: {1200 if len(filenames) > 1 else 700} + height: 450 + type: heatmap + paths: {[str(f) for f in filenames]} + output: {output_dir} + + plot4: + title: Time to First Token vs Number of Input Tokens + x_metric: num_input_tokens + y_metric: time_to_first_tokens + x_label: Number of Input Tokens + y_label: Time to First Token (ms) + width: {1200 if len(filenames) > 1 else 700} + height: 450 + type: scatter + paths: {[str(f) for f in filenames]} + output: {output_dir} + + plot5: + title: Token-to-Token Latency vs Output Token Position + x_metric: token_positions + y_metric: inter_token_latencies + x_label: Output Token Position + y_label: Token-to-Token Latency (ms) + width: {1200 if len(filenames) > 1 else 700} + height: 450 + type: scatter + paths: {[str(f) for f in filenames]} + output: {output_dir} + """ + + filepath = output_dir / "config.yaml" + logger.info(f"Creating initial YAML configuration file to {filepath}") + config = yaml.safe_load(config_str) + with open(str(filepath), "w") as f: + yaml.dump(config, f, sort_keys=False) diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_manager.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_manager.py index 395830bc9..e548a7de7 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_manager.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/plot_manager.py @@ -25,12 +25,15 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
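Taken together, the PlotConfigParser above and the reworked PlotManager below replace the old Statistics-driven plotting flow: a YAML config is parsed into PlotConfig objects, and each config is rendered as a box, scatter, or heat map plot. A minimal sketch of the intended call sequence (the profile export path and output directory here are hypothetical, and the output directory is assumed to already exist):

from pathlib import Path

from genai_perf.plots.plot_config_parser import PlotConfigParser
from genai_perf.plots.plot_manager import PlotManager

# Hypothetical inputs: one or more profile export JSON files and an output directory.
profile_files = [Path("artifacts/run1/profile_export_genai_perf.json")]
output_dir = Path("artifacts/compare")

# Write the default config.yaml, parse it into PlotConfig objects,
# then render the html/jpeg graphs and parquet files for each plot.
PlotConfigParser.create_init_yaml_config(profile_files, output_dir)
plot_configs = PlotConfigParser(output_dir / "config.yaml").generate_configs()
PlotManager(plot_configs).generate_plots()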
+from typing import List -from genai_perf.llm_metrics import Statistics +import genai_perf.logging as logging from genai_perf.plots.box_plot import BoxPlot from genai_perf.plots.heat_map import HeatMap +from genai_perf.plots.plot_config import PlotConfig, PlotType from genai_perf.plots.scatter_plot import ScatterPlot -from genai_perf.utils import scale + +logger = logging.getLogger(__name__) class PlotManager: @@ -38,81 +41,47 @@ class PlotManager: Manage details around plots generated """ - def __init__(self, stats: Statistics) -> None: - self._stats = stats - - def create_default_plots(self): - y_metric = "time_to_first_tokens" - y_key = "time_to_first_tokens_scaled" - scaled_data = [scale(x, (1 / 1e9)) for x in self._stats.metrics.data[y_metric]] - extra_data = {y_key: scaled_data} - bp_ttft = BoxPlot(self._stats, extra_data) - bp_ttft.create_plot( - y_key=y_key, - y_metric=y_metric, - graph_title="Time to First Token", - filename_root="time_to_first_token", - x_label="Time to First Token (seconds)", - ) + def __init__(self, plot_configs: List[PlotConfig]) -> None: + self._plot_configs = plot_configs - y_metric = "request_latencies" - y_key = "request_latencies_scaled" - scaled_data = [scale(x, (1 / 1e9)) for x in self._stats.metrics.data[y_metric]] - extra_data = {y_key: scaled_data} - bp_req_lat = BoxPlot(self._stats, extra_data) - bp_req_lat.create_plot( - y_key=y_key, - y_metric=y_metric, - graph_title="Request Latency", - filename_root="request_latency", - x_label="Request Latency (seconds)", - ) + def _generate_filename(self, title: str) -> str: + filename = "_".join(title.lower().split()) + return filename - hm = HeatMap(self._stats) - hm.create_plot( - x_key="num_input_tokens", - y_key="num_output_tokens", - x_metric="input_tokens", - y_metric="generated_tokens", - graph_title="Distribution of Input Tokens to Generated Tokens", - x_label="Number of Input Tokens Per Request", - y_label="Number of Generated Tokens Per Request", - filename_root="input_tokens_vs_generated_tokens", - ) + def generate_plots(self) -> None: + for plot_config in self._plot_configs: + logger.info(f"Generating '{plot_config.title}' plot") + if plot_config.type == PlotType.BOX: + bp = BoxPlot(plot_config.data) + bp.create_plot( + graph_title=plot_config.title, + x_label=plot_config.x_label, + width=plot_config.width, + height=plot_config.height, + filename_root=self._generate_filename(plot_config.title), + output_dir=plot_config.output, + ) - x_metric = "num_input_tokens" - y_metric = "time_to_first_tokens" - y_key = "time_to_first_tokens_scaled" - scaled_data = [scale(x, (1 / 1e9)) for x in self._stats.metrics.data[y_metric]] - extra_data = {y_key: scaled_data} - sp_ttft_vs_input_tokens = ScatterPlot(self._stats, extra_data) - sp_ttft_vs_input_tokens.create_plot( - x_key=x_metric, - y_key=y_key, - x_metric=x_metric, - y_metric=y_metric, - graph_title="Time to First Token vs Number of Input Tokens", - x_label="Number of Input Tokens", - y_label="Time to First Token (seconds)", - filename_root="ttft_vs_input_tokens", - ) + elif plot_config.type == PlotType.HEATMAP: + hm = HeatMap(plot_config.data) + hm.create_plot( + graph_title=plot_config.title, + x_label=plot_config.x_label, + y_label=plot_config.y_label, + width=plot_config.width, + height=plot_config.height, + filename_root=self._generate_filename(plot_config.title), + output_dir=plot_config.output, + ) - itl_latencies = self._stats.metrics.data["inter_token_latencies"] - x_data = [] - y_data = [] - for itl_latency_list in itl_latencies: - for index, 
latency in enumerate(itl_latency_list): - x_data.append(index + 1) - y_data.append(latency / 1e9) - x_key = "token_position" - y_key = "inter_token_latency" - extra_data = {x_key: x_data, y_key: y_data} - sp_tot_v_tok_pos = ScatterPlot(self._stats, extra_data) - sp_tot_v_tok_pos.create_plot( - x_key=x_key, - y_key=y_key, - graph_title="Token-to-Token Latency vs Output Token Position", - x_label="Output Token Position", - y_label="Token-to-Token Latency (seconds)", - filename_root="token_to_token_vs_output_position", - ) + elif plot_config.type == PlotType.SCATTER: + sp = ScatterPlot(plot_config.data) + sp.create_plot( + graph_title=plot_config.title, + x_label=plot_config.x_label, + y_label=plot_config.y_label, + width=plot_config.width, + height=plot_config.height, + filename_root=self._generate_filename(plot_config.title), + output_dir=plot_config.output, + ) diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/scatter_plot.py b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/scatter_plot.py index ecd78bc2f..35dca8fc3 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/plots/scatter_plot.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/plots/scatter_plot.py @@ -25,12 +25,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Dict, Optional +from pathlib import Path +from typing import List -import pandas as pd -import plotly.express as px -from genai_perf.llm_metrics import Statistics +import plotly.graph_objects as go from genai_perf.plots.base_plot import BasePlot +from genai_perf.plots.plot_config import ProfileRunData class ScatterPlot(BasePlot): @@ -38,47 +38,45 @@ class ScatterPlot(BasePlot): Generate a scatter plot in jpeg and html format. 
""" - def __init__(self, stats: Statistics, extra_data: Optional[Dict] = None) -> None: - super().__init__(stats, extra_data) + def __init__(self, data: List[ProfileRunData]) -> None: + super().__init__(data) def create_plot( self, - x_key: str = "", - y_key: str = "", - x_metric: str = "", - y_metric: str = "", graph_title: str = "", x_label: str = "", y_label: str = "", + width: int = 700, + height: int = 450, filename_root: str = "", + output_dir: Path = Path(""), ) -> None: - x_values = self._metrics_data[x_key] - y_values = self._metrics_data[y_key] - - df = pd.DataFrame( - { - x_key: x_values, - y_key: y_values, - } - ) - - fig = px.scatter( - df, - x=x_key, - y=y_key, - trendline="ols", - ) + fig = go.Figure() + for pd in self._profile_data: + fig.add_trace( + go.Scatter( + x=pd.x_metric, + y=pd.y_metric, + mode="markers", + name=pd.name, + ) + ) fig.update_layout( title={ "text": f"{graph_title}", "xanchor": "center", "x": 0.5, - } + }, + width=width, + height=height, ) fig.update_xaxes(title_text=f"{x_label}") fig.update_yaxes(title_text=f"{y_label}") - self._generate_parquet(df, filename_root) - self._generate_graph_file(fig, filename_root + ".html", graph_title) - self._generate_graph_file(fig, filename_root + ".jpeg", graph_title) + # Save dataframe as parquet file + df = self._create_dataframe(x_label, y_label) + self._generate_parquet(df, output_dir, filename_root) + + self._generate_graph_file(fig, output_dir, filename_root + ".html") + self._generate_graph_file(fig, output_dir, filename_root + ".jpeg") diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py index cf629e9a2..a10befe13 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py @@ -27,15 +27,28 @@ import json from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type + +# Skip type checking to avoid mypy error +# Issue: https://github.com/python/mypy/issues/10632 +import yaml # type: ignore def remove_sse_prefix(msg: str) -> str: - return msg.removeprefix("data: ").strip() + prefix = "data: " + if msg.startswith(prefix): + return msg[len(prefix) :].strip() + return msg.strip() + + +def load_yaml(filepath: Path) -> Dict[str, Any]: + with open(str(filepath)) as f: + configs = yaml.safe_load(f) + return configs -def load_json(filename: str) -> Dict[str, Any]: - with open(filename, encoding="utf-8", errors="ignore") as f: +def load_json(filepath: Path) -> Dict[str, Any]: + with open(str(filepath), encoding="utf-8", errors="ignore") as f: return json.load(f) @@ -48,14 +61,14 @@ def convert_option_name(name: str) -> str: return name.replace("_", "-") -def get_enum_names(enum: type[Enum]) -> List: +def get_enum_names(enum: Type[Enum]) -> List: names = [] for e in enum: names.append(e.name.lower()) return names -def get_enum_entry(name: str, enum: type[Enum]) -> Optional[Enum]: +def get_enum_entry(name: str, enum: Type[Enum]) -> Optional[Enum]: for e in enum: if e.name.lower() == name.lower(): return e diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py index a52bfa611..fa0049118 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py @@ -26,6 +26,7 @@ import subprocess from argparse import Namespace +from typing import List, Optional import genai_perf.logging 
as logging import genai_perf.utils as utils @@ -37,8 +38,8 @@ class Profiler: @staticmethod - def add_protocol_args(args: Namespace): - cmd = [""] + def add_protocol_args(args: Namespace) -> List[str]: + cmd = [] if args.service_kind == "triton": cmd += ["-i", "grpc", "--streaming"] if args.u is None: # url @@ -50,7 +51,16 @@ def add_protocol_args(args: Namespace): return cmd @staticmethod - def build_cmd(args: Namespace, extra_args: list[str] | None = None) -> list[str]: + def add_inference_load_args(args: Namespace) -> List[str]: + cmd = [] + if args.concurrency: + cmd += ["--concurrency-range", f"{args.concurrency}"] + elif args.request_rate: + cmd += ["--request-rate-range", f"{args.request_rate}"] + return cmd + + @staticmethod + def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[str]: skip_args = [ "func", "input_dataset", @@ -76,6 +86,10 @@ def build_cmd(args: Namespace, extra_args: list[str] | None = None) -> list[str] "tokenizer", "endpoint_type", "generate_plots", + "subcommand", + "concurrency", + "request_rate", + "artifact_dir", ] utils.remove_file(args.profile_export_file) @@ -86,7 +100,7 @@ def build_cmd(args: Namespace, extra_args: list[str] | None = None) -> list[str] f"{args.model}", f"--async", f"--input-data", - f"{DEFAULT_INPUT_DATA_JSON}", + f"{args.artifact_dir / DEFAULT_INPUT_DATA_JSON}", ] for arg, value in vars(args).items(): if arg in skip_args: @@ -108,6 +122,7 @@ def build_cmd(args: Namespace, extra_args: list[str] | None = None) -> list[str] cmd += [f"--{arg}", f"{value}"] cmd += Profiler.add_protocol_args(args) + cmd += Profiler.add_inference_load_args(args) if extra_args is not None: for arg in extra_args: @@ -115,7 +130,7 @@ def build_cmd(args: Namespace, extra_args: list[str] | None = None) -> list[str] return cmd @staticmethod - def run(args: Namespace, extra_args: list[str] | None) -> None: + def run(args: Namespace, extra_args: Optional[List[str]]) -> None: cmd = Profiler.build_cmd(args, extra_args) logger.info(f"Running Perf Analyzer : '{' '.join(cmd)}'") if args and args.verbose: diff --git a/src/c++/perf_analyzer/genai-perf/pyproject.toml b/src/c++/perf_analyzer/genai-perf/pyproject.toml index 22e9c6a0c..b8068bd7c 100644 --- a/src/c++/perf_analyzer/genai-perf/pyproject.toml +++ b/src/c++/perf_analyzer/genai-perf/pyproject.toml @@ -57,6 +57,7 @@ dependencies = [ "pyarrow", "fastparquet", "pytest-mock", + "pyyaml", ] # CLI Entrypoint diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py index 1e21c1d9f..bca4cac01 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py @@ -26,12 +26,16 @@ from pathlib import Path +import genai_perf.logging as logging import pytest from genai_perf import __version__, parser from genai_perf.llm_inputs.llm_inputs import OutputFormat, PromptSource class TestCLIArguments: + # ================================================ + # GENAI-PERF COMMAND + # ================================================ expected_help_output = ( "CLI to profile LLMs and Generative AI models with Perf Analyzer" ) @@ -66,7 +70,7 @@ def test_help_version_arguments_output_and_exit( @pytest.mark.parametrize( "arg, expected_attributes", [ - (["--concurrency", "3"], {"concurrency_range": "3"}), + (["--concurrency", "3"], {"concurrency": 3}), ( ["--endpoint-type", "completions", "--service-kind", "openai"], {"endpoint": "v1/completions"}, @@ -146,11 +150,15 @@ def 
test_help_version_arguments_output_and_exit( (["-p", "100"], {"measurement_interval": 100}), (["--num-prompts", "101"], {"num_prompts": 101}), ( - ["--profile-export-file", "text.txt"], - {"profile_export_file": Path("text.txt")}, + ["--profile-export-file", "test.json"], + { + "profile_export_file": Path( + "artifacts/test_model-triton-tensorrtllm-concurrency1/test.json" + ) + }, ), (["--random-seed", "8"], {"random_seed": 8}), - (["--request-rate", "9.0"], {"request_rate_range": "9.0"}), + (["--request-rate", "9.0"], {"request_rate": 9.0}), (["--service-kind", "triton"], {"service_kind": "triton"}), ( ["--service-kind", "openai", "--endpoint-type", "chat"], @@ -163,6 +171,10 @@ def test_help_version_arguments_output_and_exit( (["-v"], {"verbose": True}), (["--url", "test_url"], {"u": "test_url"}), (["-u", "test_url"], {"u": "test_url"}), + ( + ["--artifact-dir", "test_artifact_dir"], + {"artifact_dir": Path("test_artifact_dir")}, + ), ], ) def test_non_file_flags_parsed(self, monkeypatch, arg, expected_attributes, capsys): @@ -193,10 +205,93 @@ def test_file_flags_parsed(self, monkeypatch, mocker): args.input_file == mocked_open.return_value ), "The file argument should be the mock object" + @pytest.mark.parametrize( + "arg, expected_path", + [ + ( + ["--service-kind", "openai", "--endpoint-type", "chat"], + "artifacts/test_model-openai-chat-concurrency1", + ), + ( + ["--service-kind", "openai", "--endpoint-type", "completions"], + "artifacts/test_model-openai-completions-concurrency1", + ), + ( + ["--service-kind", "triton", "--backend", "tensorrtllm"], + "artifacts/test_model-triton-tensorrtllm-concurrency1", + ), + ( + ["--service-kind", "triton", "--backend", "vllm"], + "artifacts/test_model-triton-vllm-concurrency1", + ), + ( + [ + "--service-kind", + "triton", + "--backend", + "vllm", + "--concurrency", + "32", + ], + "artifacts/test_model-triton-vllm-concurrency32", + ), + ], + ) + def test_default_profile_export_filepath( + self, monkeypatch, arg, expected_path, capsys + ): + combined_args = ["genai-perf", "--model", "test_model"] + arg + monkeypatch.setattr("sys.argv", combined_args) + args, extra_args = parser.parse_args() + + assert args.artifact_dir == Path(expected_path) + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize( + "arg, expected_path, expected_output", + [ + ( + ["--model", "strange/test_model"], + "artifacts/strange_test_model-triton-tensorrtllm-concurrency1", + ( + "Model name 'strange/test_model' cannot be used to create " + "artifact directory. Instead, 'strange_test_model' will be used" + ), + ), + ( + [ + "--model", + "hello/world/test_model", + "--service-kind", + "openai", + "--endpoint-type", + "chat", + ], + "artifacts/hello_world_test_model-openai-chat-concurrency1", + ( + "Model name 'hello/world/test_model' cannot be used to create " + "artifact directory. 
Instead, 'hello_world_test_model' will be used" + ), + ), + ], + ) + def test_model_name_artifact_path( + self, monkeypatch, arg, expected_path, expected_output, capsys + ): + logging.init_logging() + combined_args = ["genai-perf"] + arg + monkeypatch.setattr("sys.argv", combined_args) + args, extra_args = parser.parse_args() + + assert args.artifact_dir == Path(expected_path) + captured = capsys.readouterr() + assert expected_output in captured.out + def test_default_load_level(self, monkeypatch, capsys): monkeypatch.setattr("sys.argv", ["genai-perf", "--model", "test_model"]) args, extra_args = parser.parse_args() - assert getattr(args, "concurrency_range") == "1" + assert args.concurrency == 1 captured = capsys.readouterr() assert captured.out == "" @@ -217,7 +312,7 @@ def test_load_level_mutually_exclusive(self, monkeypatch, capsys): def test_model_not_provided(self, monkeypatch, capsys): monkeypatch.setattr("sys.argv", ["genai-perf"]) - expected_output = "the following arguments are required: -m/--model" + expected_output = "The -m/--model option is required and cannot be empty." with pytest.raises(SystemExit) as excinfo: parser.parse_args() @@ -437,3 +532,70 @@ def test_prompt_source_assertions(self, monkeypatch, mocker, capsys): assert excinfo.value.code != 0 captured = capsys.readouterr() assert expected_output in captured.err + + # ================================================ + # COMPARE SUBCOMMAND + # ================================================ + expected_compare_help_output = ( + "Subcommand to generate plots that compare multiple profile runs." + ) + + @pytest.mark.parametrize( + "args, expected_output", + [ + (["-h"], expected_compare_help_output), + (["--help"], expected_compare_help_output), + ], + ) + def test_compare_help_arguments_output_and_exit( + self, monkeypatch, args, expected_output, capsys + ): + monkeypatch.setattr("sys.argv", ["genai-perf", "compare"] + args) + + with pytest.raises(SystemExit) as excinfo: + _ = parser.parse_args() + + # Check that the exit was successful + assert excinfo.value.code == 0 + + # Capture that the correct message was displayed + captured = capsys.readouterr() + assert expected_output in captured.out + + def test_compare_mutually_exclusive(self, monkeypatch, capsys): + args = ["genai-perf", "compare", "--config", "hello", "--files", "a", "b", "c"] + monkeypatch.setattr("sys.argv", args) + expected_output = "argument -f/--files: not allowed with argument --config" + + with pytest.raises(SystemExit) as excinfo: + parser.parse_args() + + assert excinfo.value.code != 0 + captured = capsys.readouterr() + assert expected_output in captured.err + + def test_compare_not_provided(self, monkeypatch, capsys): + args = ["genai-perf", "compare"] + monkeypatch.setattr("sys.argv", args) + expected_output = "Either the --config or --files option must be specified." 
+ + with pytest.raises(SystemExit) as excinfo: + parser.parse_args() + + assert excinfo.value.code != 0 + captured = capsys.readouterr() + assert expected_output in captured.err + + @pytest.mark.parametrize( + "args, expected_model", + [ + (["--files", "profile1.json", "profile2.json", "profile3.json"], None), + (["--config", "config.yaml"], None), + ], + ) + def test_compare_model_arg(self, monkeypatch, args, expected_model): + combined_args = ["genai-perf", "compare"] + args + monkeypatch.setattr("sys.argv", combined_args) + args, _ = parser.parse_args() + + assert args.model == expected_model diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py index 3e1c0bbb3..c57bd13d3 100755 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py @@ -28,7 +28,8 @@ import json from io import StringIO -from typing import Any, List +from pathlib import Path +from typing import Any, List, Union import numpy as np import pytest @@ -38,7 +39,7 @@ from transformers import AutoTokenizer -def ns_to_sec(ns: int) -> int | float: +def ns_to_sec(ns: int) -> Union[int, float]: """Convert from nanosecond to second.""" return ns / 1e9 @@ -73,6 +74,9 @@ def write(self: Any, content: str) -> int: elif filename == "openai_profile_export.json": tmp_file = StringIO(json.dumps(self.openai_profile_data)) return tmp_file + elif filename == "empty_profile_export.json": + tmp_file = StringIO(json.dumps(self.empty_profile_data)) + return tmp_file elif filename == "profile_export.csv": tmp_file = StringIO() tmp_file.write = write.__get__(tmp_file) @@ -92,9 +96,7 @@ def test_csv_output(self, mock_read_write: pytest.MonkeyPatch) -> None: tokenizer = get_tokenizer(DEFAULT_TOKENIZER) pd = LLMProfileDataParser( - filename="triton_profile_export.json", - service_kind="triton", - output_format=OutputFormat.TENSORRTLLM, + filename=Path("triton_profile_export.json"), tokenizer=tokenizer, ) stat = pd.get_statistics(infer_mode="concurrency", load_level="10") @@ -147,9 +149,7 @@ def test_triton_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N """ tokenizer = get_tokenizer(DEFAULT_TOKENIZER) pd = LLMProfileDataParser( - filename="triton_profile_export.json", - service_kind="triton", - output_format=OutputFormat.TENSORRTLLM, + filename=Path("triton_profile_export.json"), tokenizer=tokenizer, ) @@ -290,9 +290,7 @@ def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N """ tokenizer = get_tokenizer(DEFAULT_TOKENIZER) pd = LLMProfileDataParser( - filename="openai_profile_export.json", - service_kind="openai", - output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS, + filename=Path("openai_profile_export.json"), tokenizer=tokenizer, ) @@ -375,9 +373,7 @@ def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: tokenizer = get_tokenizer(DEFAULT_TOKENIZER) pd = LLMProfileDataParser( - filename="openai_profile_export.json", - service_kind="openai", - output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS, + filename=Path("openai_profile_export.json"), tokenizer=tokenizer, ) @@ -408,7 +404,50 @@ def test_llm_metrics_get_base_name(self) -> None: with pytest.raises(KeyError): metrics.get_base_name("hello1234") + def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: + """Check if it handles all empty responses.""" + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + + # Should not throw error + _ = LLMProfileDataParser( + 
filename=Path("empty_profile_export.json"), + tokenizer=tokenizer, + ) + + empty_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":"This is test"}],"model":"llama-2-7b","stream":true}', + }, + "response_timestamps": [3, 5, 8], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"content":""},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + ], + }, + ], + } + openai_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", "experiments": [ { "experiment": { @@ -474,6 +513,8 @@ def test_llm_metrics_get_base_name(self) -> None: } triton_profile_data = { + "service_kind": "triton", + "endpoint": "", "experiments": [ { "experiment": { diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_plot_configs.py b/src/c++/perf_analyzer/genai-perf/tests/test_plot_configs.py new file mode 100644 index 000000000..1e1391e4c --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/tests/test_plot_configs.py @@ -0,0 +1,112 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from pathlib import Path + +# Skip type checking to avoid mypy error +# Issue: https://github.com/python/mypy/issues/10632 +import yaml # type: ignore +from genai_perf.plots.plot_config import PlotType +from genai_perf.plots.plot_config_parser import PlotConfigParser + + +class TestPlotConfigParser: + yaml_config = """ + plot1: + title: TTFT vs ITL + x_metric: time_to_first_tokens + y_metric: inter_token_latencies + x_label: TTFT (ms) + y_label: ITL (ms) + width: 1000 + height: 3000 + type: box + paths: + - run1/concurrency32.json + - run2/concurrency32.json + - run3/concurrency32.json + output: test_output_1 + + plot2: + title: Num Input Token vs Num Output Token + x_metric: num_input_tokens + y_metric: num_output_tokens + x_label: Input Tokens + y_label: Output Tokens + width: 1234 + height: 5678 + type: scatter + paths: + - run4/concurrency1.json + output: test_output_2 + """ + + def test_generate_configs(self, monkeypatch) -> None: + monkeypatch.setattr( + "genai_perf.plots.plot_config_parser.load_yaml", + lambda _: yaml.safe_load(self.yaml_config), + ) + monkeypatch.setattr(PlotConfigParser, "_get_statistics", lambda *_: {}) + monkeypatch.setattr(PlotConfigParser, "_get_metric", lambda *_: [1, 2, 3]) + + config_parser = PlotConfigParser(Path("test_config.yaml")) + plot_configs = config_parser.generate_configs() + + assert len(plot_configs) == 2 + pc1, pc2 = plot_configs + + # plot config 1 + assert pc1.title == "TTFT vs ITL" + assert pc1.x_label == "TTFT (ms)" + assert pc1.y_label == "ITL (ms)" + assert pc1.width == 1000 + assert pc1.height == 3000 + assert pc1.type == PlotType.BOX + assert pc1.output == Path("test_output_1") + + assert len(pc1.data) == 3 # profile run data + prd1, prd2, prd3 = pc1.data + assert prd1.name == "run1/concurrency32" + assert prd2.name == "run2/concurrency32" + assert prd3.name == "run3/concurrency32" + for prd in pc1.data: + assert prd.x_metric == [1, 2, 3] + assert prd.y_metric == [1, 2, 3] + + # plot config 2 + assert pc2.title == "Num Input Token vs Num Output Token" + assert pc2.x_label == "Input Tokens" + assert pc2.y_label == "Output Tokens" + assert pc2.width == 1234 + assert pc2.height == 5678 + assert pc2.type == PlotType.SCATTER + assert pc2.output == Path("test_output_2") + + assert len(pc2.data) == 1 # profile run data + prd = pc2.data[0] + assert prd.name == "run4/concurrency1" + assert prd.x_metric == [1, 2, 3] + assert prd.y_metric == [1, 2, 3] diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_wrapper.py b/src/c++/perf_analyzer/genai-perf/tests/test_wrapper.py index 66044477d..184a47f11 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_wrapper.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_wrapper.py @@ -52,6 +52,33 @@ def test_url_exactly_once_triton(self, monkeypatch, arg): number_of_url_args = cmd_string.count(" -u ") + cmd_string.count(" --url ") assert number_of_url_args == 1 + @pytest.mark.parametrize( + "arg, expected_filepath", + [ + ( + [], + "artifacts/test_model-triton-tensorrtllm-concurrency1/profile_export.json", + ), + ( + ["--artifact-dir", "test_dir"], + "test_dir/profile_export.json", + ), + ( + ["--artifact-dir", "test_dir", "--profile-export-file", "test.json"], + "test_dir/test.json", + ), + ], + ) + def test_profile_export_filepath(self, monkeypatch, arg, expected_filepath): + args = ["genai-perf", "-m", "test_model", "--service-kind", "triton"] + arg + monkeypatch.setattr("sys.argv", args) + args, extra_args = parser.parse_args() + cmd = Profiler.build_cmd(args, extra_args) + cmd_string = " 
".join(cmd) + + expected_pattern = f"--profile-export-file {expected_filepath}" + assert expected_pattern in cmd_string + @pytest.mark.parametrize( "arg", [ diff --git a/src/c++/perf_analyzer/inference_profiler.cc b/src/c++/perf_analyzer/inference_profiler.cc index 46e2bcb52..bd497caf5 100644 --- a/src/c++/perf_analyzer/inference_profiler.cc +++ b/src/c++/perf_analyzer/inference_profiler.cc @@ -107,6 +107,14 @@ EnsembleDurations GetTotalEnsembleDurations(const ServerSideStats& stats) { EnsembleDurations result; + // Calculate avg cache hit latency and cache miss latency for ensemble model + // in case top level response caching is enabled. + const uint64_t ensemble_cache_hit_cnt = stats.cache_hit_count; + const uint64_t ensemble_cache_miss_cnt = stats.cache_miss_count; + result.total_cache_hit_time_avg_us += + AverageDurationInUs(stats.cache_hit_time_ns, ensemble_cache_hit_cnt); + result.total_cache_miss_time_avg_us += + AverageDurationInUs(stats.cache_miss_time_ns, ensemble_cache_miss_cnt); for (const auto& model_stats : stats.composing_models_stat) { if (model_stats.second.composing_models_stat.empty()) { // Cache hit count covers cache hits, not related to compute times @@ -238,7 +246,6 @@ ReportServerSideStats( if (parser->ResponseCacheEnabled()) { const uint64_t overhead_avg_us = GetOverheadDuration( cumm_avg_us, queue_avg_us, combined_cache_compute_avg_us); - std::cout << " (overhead " << overhead_avg_us << " usec + " << "queue " << queue_avg_us << " usec + " << "cache hit/miss " << combined_cache_compute_avg_us @@ -283,12 +290,18 @@ ReportServerSideStats( const uint64_t overhead_avg_us = GetOverheadDuration( cumm_avg_us, ensemble_times.total_queue_time_avg_us, ensemble_times.total_combined_cache_compute_time_avg_us); - std::cout << " (overhead " << overhead_avg_us << " usec + " - << "queue " << ensemble_times.total_queue_time_avg_us - << " usec + " - << "cache hit/miss " - << ensemble_times.total_combined_cache_compute_time_avg_us - << " usec)" << std::endl; + // FIXME - Refactor these calculations in case of ensemble top level + // response cache is enabled + if (!parser->TopLevelResponseCachingEnabled()) { + std::cout << " (overhead " << overhead_avg_us << " usec + " + << "queue " << ensemble_times.total_queue_time_avg_us + << " usec + " + << "cache hit/miss " + << ensemble_times.total_combined_cache_compute_time_avg_us + << " usec)" << std::endl; + } else { + std::cout << std::endl; + } std::cout << ident << ident << " Average Cache Hit Latency: " << ensemble_times.total_cache_hit_time_avg_us << " usec" << std::endl; @@ -533,7 +546,7 @@ InferenceProfiler::InferenceProfiler( cb::Error InferenceProfiler::Profile( - const size_t concurrent_request_count, + const size_t concurrent_request_count, const size_t request_count, std::vector& perf_statuses, bool& meets_threshold, bool& is_stable) { @@ -545,10 +558,11 @@ InferenceProfiler::Profile( is_stable = false; meets_threshold = true; - RETURN_IF_ERROR(dynamic_cast(manager_.get()) - ->ChangeConcurrencyLevel(concurrent_request_count)); + RETURN_IF_ERROR( + dynamic_cast(manager_.get()) + ->ChangeConcurrencyLevel(concurrent_request_count, request_count)); - err = ProfileHelper(perf_status, &is_stable); + err = ProfileHelper(perf_status, request_count, &is_stable); if (err.IsOk()) { uint64_t stabilizing_latency_ms = perf_status.stabilizing_latency_ns / NANOS_PER_MILLIS; @@ -590,8 +604,9 @@ InferenceProfiler::Profile( cb::Error InferenceProfiler::Profile( - const double request_rate, std::vector& perf_statuses, - bool& meets_threshold, 
bool& is_stable) + const double request_rate, const size_t request_count, + std::vector& perf_statuses, bool& meets_threshold, + bool& is_stable) { cb::Error err; PerfStatus perf_status{}; @@ -602,11 +617,11 @@ InferenceProfiler::Profile( meets_threshold = true; RETURN_IF_ERROR(dynamic_cast(manager_.get()) - ->ChangeRequestRate(request_rate)); + ->ChangeRequestRate(request_rate, request_count)); std::cout << "Request Rate: " << request_rate << " inference requests per seconds" << std::endl; - err = ProfileHelper(perf_status, &is_stable); + err = ProfileHelper(perf_status, request_count, &is_stable); if (err.IsOk()) { uint64_t stabilizing_latency_ms = perf_status.stabilizing_latency_ns / NANOS_PER_MILLIS; @@ -638,21 +653,21 @@ InferenceProfiler::Profile( cb::Error InferenceProfiler::Profile( - std::vector& perf_statuses, bool& meets_threshold, - bool& is_stable) + const size_t request_count, std::vector& perf_statuses, + bool& meets_threshold, bool& is_stable) { cb::Error err; PerfStatus perf_status{}; - RETURN_IF_ERROR( - dynamic_cast(manager_.get())->InitCustomIntervals()); + RETURN_IF_ERROR(dynamic_cast(manager_.get()) + ->InitCustomIntervals(request_count)); RETURN_IF_ERROR(dynamic_cast(manager_.get()) ->GetCustomRequestRate(&perf_status.request_rate)); is_stable = false; meets_threshold = true; - err = ProfileHelper(perf_status, &is_stable); + err = ProfileHelper(perf_status, request_count, &is_stable); if (err.IsOk()) { uint64_t stabilizing_latency_ms = perf_status.stabilizing_latency_ns / NANOS_PER_MILLIS; @@ -684,7 +699,7 @@ InferenceProfiler::Profile( cb::Error InferenceProfiler::ProfileHelper( - PerfStatus& experiment_perf_status, bool* is_stable) + PerfStatus& experiment_perf_status, size_t request_count, bool* is_stable) { // Start measurement LoadStatus load_status; @@ -759,6 +774,12 @@ InferenceProfiler::ProfileHelper( } } + // If request-count is specified, then only measure one window and exit + if (request_count != 0) { + *is_stable = true; + break; + } + *is_stable = DetermineStability(load_status); if (IsDoneProfiling(load_status, is_stable)) { @@ -1516,8 +1537,16 @@ InferenceProfiler::DetermineStatsModelVersion( *status_model_version = std::stoll(model_identifier.second); } } - - if (*status_model_version == -1) { + // FIXME - Investigate why composing model version is -1 in case of ensemble + // cache hit. + // + // In case of ensemble models, if top level response caching is enabled, + // the composing model versions are unavailable on a cache hit because the + // scheduler serves the response from the cache and the composing models are + // not executed. This is a valid scenario, not an error.
+ bool model_version_unspecified_and_invalid = + *status_model_version == -1 && !parser_->TopLevelResponseCachingEnabled(); + if (model_version_unspecified_and_invalid) { return cb::Error( "failed to find the requested model version", pa::GENERIC_ERROR); } @@ -1533,6 +1562,21 @@ InferenceProfiler::DetermineStatsModelVersion( return cb::Error::Success; } +// Only for unit-testing +#ifndef DOCTEST_CONFIG_DISABLE +cb::Error +InferenceProfiler::SetTopLevelResponseCaching( + bool enable_top_level_response_caching) +{ + parser_ = std::make_shared(cb::BackendKind::TRITON); + if (parser_ == nullptr) { + return cb::Error("Failed to initialize ModelParser"); + } + parser_->SetTopLevelResponseCaching(enable_top_level_response_caching); + return cb::Error::Success; +} +#endif + cb::Error InferenceProfiler::SummarizeServerStats( const std::map& start_status, @@ -1588,8 +1632,20 @@ InferenceProfiler::SummarizeServerStatsHelper( const auto& end_itr = end_status.find(this_id); if (end_itr == end_status.end()) { - return cb::Error( - "missing statistics for requested model", pa::GENERIC_ERROR); + // In case of ensemble models, if top level response caching is enabled, + // the composing model statistics are unavailable on a cache hit because the + // scheduler serves the response from the cache and the composing models are + // not executed. This is a valid scenario, not an error. + bool stats_not_found_and_invalid = + model_version == -1 && !parser_->TopLevelResponseCachingEnabled(); + if (stats_not_found_and_invalid) { + return cb::Error( + "missing statistics for requested model", pa::GENERIC_ERROR); + } else { + // Set the composing model's server stats to 0 on an ensemble cache hit, + // since the composing model will not be executed + server_stats->Reset(); + } } else { uint64_t start_infer_cnt = 0; uint64_t start_exec_cnt = 0; diff --git a/src/c++/perf_analyzer/inference_profiler.h b/src/c++/perf_analyzer/inference_profiler.h index 913b23ded..013dd0483 100644 --- a/src/c++/perf_analyzer/inference_profiler.h +++ b/src/c++/perf_analyzer/inference_profiler.h @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -52,6 +52,7 @@ namespace triton { namespace perfanalyzer { #ifndef DOCTEST_CONFIG_DISABLE class NaggyMockInferenceProfiler; class TestInferenceProfiler; +class ModelParser; #endif /// Constant parameters that determine the whether stopping criteria has met @@ -119,6 +120,28 @@ struct ServerSideStats { uint64_t cache_miss_time_ns; std::map composing_models_stat; + // Sets the composing model server stats to 0 for a cache hit when top level + // response caching is enabled, since the composing models are not executed + // and do not have any stats. + void Reset() + { + inference_count = 0; + execution_count = 0; + success_count = 0; + queue_count = 0; + compute_input_count = 0; + compute_infer_count = 0; + compute_output_count = 0; + cumm_time_ns = 0; + queue_time_ns = 0; + compute_input_time_ns = 0; + compute_infer_time_ns = 0; + compute_output_time_ns = 0; + cache_hit_count = 0; + cache_hit_time_ns = 0; + cache_miss_count = 0; + cache_miss_time_ns = 0; + } }; /// Holds the statistics recorded at the client side.
@@ -248,25 +271,29 @@ class InferenceProfiler { /// \param step The step size to move along the search range in linear search /// or the precision in binary search. /// \param search_mode The search algorithm to be applied. - /// \param summary Returns the trace of the measurement along the search - /// path. + /// \param request_count The number of requests to generate in each + /// experiment. If 0, then there is no limit, and it will generate until + /// stable. + /// \param summary Returns the trace of the measurement along the search path. /// \return cb::Error object indicating success or failure. template cb::Error Profile( const T start, const T end, const T step, const SearchMode search_mode, - std::vector& perf_statuses) + const size_t request_count, std::vector& perf_statuses) { cb::Error err; bool meets_threshold, is_stable; if (search_mode == SearchMode::NONE) { - err = Profile(perf_statuses, meets_threshold, is_stable); + err = Profile(request_count, perf_statuses, meets_threshold, is_stable); if (!err.IsOk()) { return err; } } else if (search_mode == SearchMode::LINEAR) { T current_value = start; do { - err = Profile(current_value, perf_statuses, meets_threshold, is_stable); + err = Profile( + current_value, request_count, perf_statuses, meets_threshold, + is_stable); if (!err.IsOk()) { return err; } @@ -280,11 +307,13 @@ class InferenceProfiler { "Failed to obtain stable measurement.", pa::STABILITY_ERROR); } } else { - err = Profile(start, perf_statuses, meets_threshold, is_stable); + err = Profile( + start, request_count, perf_statuses, meets_threshold, is_stable); if (!err.IsOk() || (!meets_threshold)) { return err; } - err = Profile(end, perf_statuses, meets_threshold, is_stable); + err = Profile( + end, request_count, perf_statuses, meets_threshold, is_stable); if (!err.IsOk() || (meets_threshold)) { return err; } @@ -293,7 +322,9 @@ class InferenceProfiler { T this_end = end; while ((this_end - this_start) > step) { T current_value = (this_end + this_start) / 2; - err = Profile(current_value, perf_statuses, meets_threshold, is_stable); + err = Profile( + current_value, request_count, perf_statuses, meets_threshold, + is_stable); if (!err.IsOk()) { return err; } @@ -346,43 +377,58 @@ class InferenceProfiler { /// request and right after the last request in the measurement window). /// \param concurrent_request_count The concurrency level for the measurement. /// \param perf_statuses Appends the measurements summary at the end of this - /// list. \param meets_threshold Returns whether the setting meets the + /// list. + /// \param request_count The number of requests to generate when profiling. If + /// 0, then there is no limit, and it will generate until stable. + /// \param meets_threshold Returns whether the setting meets the /// threshold. /// \param is_stable Returns whether the measurement is stable. /// \return cb::Error object indicating success or failure. cb::Error Profile( - const size_t concurrent_request_count, + const size_t concurrent_request_count, const size_t request_count, std::vector& perf_statuses, bool& meets_threshold, bool& is_stable); /// Similar to above function, but instead of setting the concurrency, it /// sets the specified request rate for measurements. /// \param request_rate The request rate for inferences. + /// \param request_count The number of requests to generate when profiling. If + /// 0, then there is no limit, and it will generate until stable. 
/// \param perf_statuses Appends the measurements summary at the end of this - /// list. \param meets_threshold Returns whether the setting meets the - /// threshold. \param is_stable Returns whether the measurement is stable. + /// list. + /// \param meets_threshold Returns whether the setting meets the + /// threshold. + /// \param is_stable Returns whether the measurement is stable. /// \return cb::Error object indicating success or failure. cb::Error Profile( - const double request_rate, std::vector& perf_statuses, - bool& meets_threshold, bool& is_stable); + const double request_rate, const size_t request_count, + std::vector& perf_statuses, bool& meets_threshold, + bool& is_stable); /// Measures throughput and latencies for custom load without controlling /// request rate nor concurrency. Requires load manager to be loaded with /// a file specifying the time intervals. + /// \param request_count The number of requests to generate when profiling. If + /// 0, then there is no limit, and it will generate until stable. /// \param perf_statuses Appends the measurements summary at the end of this - /// list. \param meets_threshold Returns whether the measurement met the - /// threshold. \param is_stable Returns whether the measurement is stable. + /// list. + /// \param meets_threshold Returns whether the measurement met the + /// threshold. + /// \param is_stable Returns whether the measurement is stable. /// \return cb::Error object indicating success /// or failure. cb::Error Profile( - std::vector& perf_statuses, bool& meets_threshold, - bool& is_stable); + const size_t request_count, std::vector& perf_statuses, + bool& meets_threshold, bool& is_stable); /// A helper function for profiling functions. /// \param status_summary Returns the summary of the measurement. + /// \param request_count The number of requests to generate when profiling. If + /// 0, then there is no limit, and it will generate until stable. /// \param is_stable Returns whether the measurement stabilized or not. /// \return cb::Error object indicating success or failure. - cb::Error ProfileHelper(PerfStatus& status_summary, bool* is_stable); + cb::Error ProfileHelper( + PerfStatus& status_summary, size_t request_count, bool* is_stable); /// A helper function to determine if profiling is stable /// \param load_status Stores the observations of infer_per_sec and latencies @@ -530,12 +576,17 @@ class InferenceProfiler { /// measurement /// \param end_stats The stats for all models at the end of the measurement /// \param model_version The determined model version + cb::Error DetermineStatsModelVersion( const cb::ModelIdentifier& model_identifier, const std::map& start_stats, const std::map& end_stats, int64_t* model_version); +#ifndef DOCTEST_CONFIG_DISABLE + cb::Error SetTopLevelResponseCaching(bool enable_top_level_request_caching); +#endif + /// \param start_status The model status at the start of the measurement. /// \param end_status The model status at the end of the measurement. 
/// \param server_stats Returns the summary that the fields recorded by server @@ -738,6 +789,7 @@ class InferenceProfiler { #ifndef DOCTEST_CONFIG_DISABLE friend NaggyMockInferenceProfiler; friend TestInferenceProfiler; + friend ModelParser; public: InferenceProfiler() = default; diff --git a/src/c++/perf_analyzer/load_worker.cc b/src/c++/perf_analyzer/load_worker.cc index e3efe43bb..a32976c6a 100644 --- a/src/c++/perf_analyzer/load_worker.cc +++ b/src/c++/perf_analyzer/load_worker.cc @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -37,8 +37,14 @@ namespace triton { namespace perfanalyzer { bool LoadWorker::ShouldExit() { - return early_exit || !thread_stat_->cb_status_.IsOk() || - !thread_stat_->status_.IsOk(); + bool bad_status = + !thread_stat_->cb_status_.IsOk() || !thread_stat_->status_.IsOk(); + + bool done_with_request_count = + thread_config_->num_requests_ != 0 && + thread_stat_->num_sent_requests_ >= thread_config_->num_requests_; + + return early_exit || bad_status || done_with_request_count; } bool @@ -46,6 +52,7 @@ LoadWorker::HandleExitConditions() { if (ShouldExit()) { CompleteOngoingSequences(); + thread_stat_->idle_timer.Start(); WaitForOngoingRequests(); return true; } diff --git a/src/c++/perf_analyzer/load_worker.h b/src/c++/perf_analyzer/load_worker.h index 12781f4fe..dd7e0297f 100644 --- a/src/c++/perf_analyzer/load_worker.h +++ b/src/c++/perf_analyzer/load_worker.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -36,6 +36,7 @@ #include "iworker.h" #include "model_parser.h" #include "sequence_manager.h" +#include "thread_config.h" namespace triton { namespace perfanalyzer { @@ -45,6 +46,7 @@ class LoadWorker : public IWorker { protected: LoadWorker( uint32_t id, std::shared_ptr thread_stat, + std::shared_ptr thread_config, const std::shared_ptr parser, std::shared_ptr data_loader, const std::shared_ptr factory, @@ -54,8 +56,8 @@ class LoadWorker : public IWorker { bool& execute, const std::shared_ptr& infer_data_manager, std::shared_ptr sequence_manager) - : id_(id), thread_stat_(thread_stat), parser_(parser), - data_loader_(data_loader), factory_(factory), + : id_(id), thread_stat_(thread_stat), thread_config_(thread_config), + parser_(parser), data_loader_(data_loader), factory_(factory), on_sequence_model_(on_sequence_model), async_(async), streaming_(streaming), batch_size_(batch_size), using_json_data_(using_json_data), wake_signal_(wake_signal), @@ -137,6 +139,8 @@ class LoadWorker : public IWorker { // Stats for this thread std::shared_ptr thread_stat_; + // Configuration for this thread + std::shared_ptr thread_config_; std::shared_ptr data_loader_; const std::shared_ptr parser_; diff --git a/src/c++/perf_analyzer/mock_profile_data_exporter.h b/src/c++/perf_analyzer/mock_profile_data_exporter.h index 66173b4a0..90e96d736 100644 --- a/src/c++/perf_analyzer/mock_profile_data_exporter.h +++ b/src/c++/perf_analyzer/mock_profile_data_exporter.h @@ -34,13 +34,15 @@ class NaggyMockProfileDataExporter : public ProfileDataExporter { public: NaggyMockProfileDataExporter() { - ON_CALL(*this, ConvertToJson(testing::_, testing::_)) + ON_CALL( + *this, ConvertToJson(testing::_, testing::_, testing::_, testing::_)) .WillByDefault( [this]( const std::vector& raw_experiments, - std::string& raw_version) -> void { + std::string& raw_version, cb::BackendKind& service_kind, + std::string& endpoint) -> void { return this->ProfileDataExporter::ConvertToJson( - raw_experiments, raw_version); + raw_experiments, raw_version, service_kind, endpoint); }); ON_CALL(*this, OutputToFile(testing::_)) @@ -56,15 +58,34 @@ class NaggyMockProfileDataExporter : public ProfileDataExporter { this->ProfileDataExporter::AddExperiment( entry, experiment, raw_experiment); }); + + ON_CALL(*this, AddServiceKind(testing::_)) + .WillByDefault([this](cb::BackendKind& service_kind) -> void { + this->ProfileDataExporter::AddServiceKind(service_kind); + }); + + ON_CALL(*this, AddEndpoint(testing::_)) + .WillByDefault([this](std::string& endpoint) -> void { + this->ProfileDataExporter::AddEndpoint(endpoint); + }); + + ON_CALL(*this, ClearDocument()).WillByDefault([this]() -> void { + this->ProfileDataExporter::ClearDocument(); + }); } MOCK_METHOD( - void, ConvertToJson, (const std::vector&, std::string&), + void, ConvertToJson, + (const std::vector&, std::string&, cb::BackendKind&, + std::string&), (override)); MOCK_METHOD( void, AddExperiment, (rapidjson::Value&, rapidjson::Value&, const Experiment&), (override)); MOCK_METHOD(void, OutputToFile, (std::string&), (override)); + MOCK_METHOD(void, AddServiceKind, (cb::BackendKind&)); + MOCK_METHOD(void, AddEndpoint, (std::string&)); + MOCK_METHOD(void, ClearDocument, ()); rapidjson::Document& document_{ProfileDataExporter::document_}; }; diff --git a/src/c++/perf_analyzer/model_parser.cc b/src/c++/perf_analyzer/model_parser.cc index 
1ab9f7a6d..8ffea56da 100644 --- a/src/c++/perf_analyzer/model_parser.cc +++ b/src/c++/perf_analyzer/model_parser.cc @@ -169,6 +169,10 @@ ModelParser::InitTriton( response_cache_enabled_ = cache_itr->value["enable"].GetBool(); } + if (cache_itr != config.MemberEnd()) { + top_level_response_caching_enabled_ = cache_itr->value["enable"].GetBool(); + } + return cb::Error::Success; } diff --git a/src/c++/perf_analyzer/model_parser.h b/src/c++/perf_analyzer/model_parser.h index c1400d079..ac76b3e22 100644 --- a/src/c++/perf_analyzer/model_parser.h +++ b/src/c++/perf_analyzer/model_parser.h @@ -35,6 +35,7 @@ namespace triton { namespace perfanalyzer { #ifndef DOCTEST_CONFIG_DISABLE class TestModelParser; class MockModelParser; +class InferenceProfiler; #endif struct ModelTensor { @@ -73,7 +74,8 @@ class ModelParser { outputs_(std::make_shared()), composing_models_map_(std::make_shared()), scheduler_type_(NONE), max_batch_size_(0), is_decoupled_(false), - response_cache_enabled_(false) + response_cache_enabled_(false), + top_level_response_caching_enabled_(false) { } @@ -151,6 +153,22 @@ class ModelParser { /// model bool ResponseCacheEnabled() const { return response_cache_enabled_; } + /// Returns whether or not top level response caching is enabled for this model + /// \return the truth value of whether top level response caching is enabled + /// for this model + bool TopLevelResponseCachingEnabled() const + { + return top_level_response_caching_enabled_; + } + +/// Only for testing +#ifndef DOCTEST_CONFIG_DISABLE + void SetTopLevelResponseCaching(bool enable_top_level_response_caching) + { + top_level_response_caching_enabled_ = enable_top_level_response_caching; + } +#endif + /// Get the details about the model inputs. /// \return The map with tensor_name and the tensor details /// stored as key-value pair.
@@ -169,6 +187,7 @@ class ModelParser { return composing_models_map_; } + protected: ModelSchedulerType scheduler_type_; bool is_decoupled_; @@ -220,10 +239,12 @@ class ModelParser { std::string model_signature_name_; size_t max_batch_size_; bool response_cache_enabled_; + bool top_level_response_caching_enabled_; #ifndef DOCTEST_CONFIG_DISABLE friend TestModelParser; friend MockModelParser; + friend InferenceProfiler; public: ModelParser() = default; diff --git a/src/c++/perf_analyzer/perf_analyzer.cc b/src/c++/perf_analyzer/perf_analyzer.cc index ced5fc991..b8b4de7ea 100644 --- a/src/c++/perf_analyzer/perf_analyzer.cc +++ b/src/c++/perf_analyzer/perf_analyzer.cc @@ -295,29 +295,39 @@ PerfAnalyzer::PrerunReport() if (params_->kind == cb::BackendKind::TRITON || params_->using_batch_size) { std::cout << " Batch size: " << params_->batch_size << std::endl; } - if (params_->kind == cb::BackendKind::TRITON_C_API) { - std::cout << " Service Kind: Triton C-API" << std::endl; - } else if (params_->kind == cb::BackendKind::TRITON) { - std::cout << " Service Kind: Triton" << std::endl; - } else if (params_->kind == cb::BackendKind::TORCHSERVE) { - std::cout << " Service Kind: TorchServe" << std::endl; - } else if (params_->kind == cb::BackendKind::TENSORFLOW_SERVING) { - std::cout << " Service Kind: TensorFlow Serving" << std::endl; - } - if (params_->measurement_mode == pa::MeasurementMode::COUNT_WINDOWS) { - std::cout << " Using \"count_windows\" mode for stabilization" - << std::endl; + std::cout << " Service Kind: " << BackendKindToString(params_->kind) + << std::endl; + + if (params_->request_count != 0) { + std::cout << " Sending a total of " << params_->request_count + << " requests" << std::endl; } else { - std::cout << " Using \"time_windows\" mode for stabilization" << std::endl; - } - if (params_->measurement_mode == pa::MeasurementMode::TIME_WINDOWS) { - std::cout << " Measurement window: " << params_->measurement_window_ms - << " msec" << std::endl; - } else if (params_->measurement_mode == pa::MeasurementMode::COUNT_WINDOWS) { - std::cout << " Minimum number of samples in each window: " - << params_->measurement_request_count << std::endl; + if (params_->measurement_mode == pa::MeasurementMode::COUNT_WINDOWS) { + std::cout << " Using \"count_windows\" mode for stabilization" + << std::endl; + } else { + std::cout << " Using \"time_windows\" mode for stabilization" + << std::endl; + } + + if (params_->percentile == -1) { + std::cout << " Stabilizing using average latency" << std::endl; + } else { + std::cout << " Stabilizing using p" << params_->percentile << " latency" + << std::endl; + } + + if (params_->measurement_mode == pa::MeasurementMode::TIME_WINDOWS) { + std::cout << " Measurement window: " << params_->measurement_window_ms + << " msec" << std::endl; + } else if ( + params_->measurement_mode == pa::MeasurementMode::COUNT_WINDOWS) { + std::cout << " Minimum number of samples in each window: " + << params_->measurement_request_count << std::endl; + } } + if (params_->concurrency_range.end != 1) { std::cout << " Latency limit: " << params_->latency_threshold_ms << " msec" << std::endl; @@ -364,12 +374,6 @@ PerfAnalyzer::PrerunReport() << std::endl; } - if (params_->percentile == -1) { - std::cout << " Stabilizing using average latency" << std::endl; - } else { - std::cout << " Stabilizing using p" << params_->percentile << " latency" - << std::endl; - } std::cout << std::endl; } @@ -382,7 +386,8 @@ PerfAnalyzer::Profile() if (params_->targeting_concurrency()) { err = 
profiler_->Profile( params_->concurrency_range.start, params_->concurrency_range.end, - params_->concurrency_range.step, params_->search_mode, perf_statuses_); + params_->concurrency_range.step, params_->search_mode, + params_->request_count, perf_statuses_); } else if (params_->is_using_periodic_concurrency_mode) { err = profiler_->ProfilePeriodicConcurrencyMode(); } else { @@ -390,7 +395,7 @@ PerfAnalyzer::Profile() params_->request_rate_range[pa::SEARCH_RANGE::kSTART], params_->request_rate_range[pa::SEARCH_RANGE::kEND], params_->request_rate_range[pa::SEARCH_RANGE::kSTEP], - params_->search_mode, perf_statuses_); + params_->search_mode, params_->request_count, perf_statuses_); } params_->mpi_driver->MPIBarrierWorld(); @@ -452,7 +457,7 @@ PerfAnalyzer::GenerateProfileExport() if (!params_->profile_export_file.empty()) { exporter_->Export( collector_->GetData(), collector_->GetVersion(), - params_->profile_export_file); + params_->profile_export_file, params_->kind, params_->endpoint); } } diff --git a/src/c++/perf_analyzer/periodic_concurrency_manager.cc b/src/c++/perf_analyzer/periodic_concurrency_manager.cc index 1a5527b7b..a8375ed65 100644 --- a/src/c++/perf_analyzer/periodic_concurrency_manager.cc +++ b/src/c++/perf_analyzer/periodic_concurrency_manager.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -39,7 +39,7 @@ PeriodicConcurrencyManager::RunExperiment() std::shared_ptr PeriodicConcurrencyManager::MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) + std::shared_ptr thread_config) { uint32_t id = workers_.size(); auto worker = std::make_shared( @@ -66,8 +66,9 @@ PeriodicConcurrencyManager::AddConcurrentRequest(size_t seq_stat_index_offset) { threads_stat_.emplace_back(std::make_shared()); threads_config_.emplace_back( - std::make_shared( - threads_config_.size(), 1, seq_stat_index_offset)); + std::make_shared(threads_config_.size())); + threads_config_.back()->concurrency_ = 1; + threads_config_.back()->seq_stat_index_offset_ = seq_stat_index_offset; workers_.emplace_back( MakeWorker(threads_stat_.back(), threads_config_.back())); threads_.emplace_back(&IWorker::Infer, workers_.back()); diff --git a/src/c++/perf_analyzer/periodic_concurrency_manager.h b/src/c++/perf_analyzer/periodic_concurrency_manager.h index db612fd96..40a0634b4 100644 --- a/src/c++/perf_analyzer/periodic_concurrency_manager.h +++ b/src/c++/perf_analyzer/periodic_concurrency_manager.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -61,8 +61,7 @@ class PeriodicConcurrencyManager : public ConcurrencyManager { private: std::shared_ptr MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) - override; + std::shared_ptr thread_config) override; void AddConcurrentRequests(uint64_t num_concurrent_requests); diff --git a/src/c++/perf_analyzer/profile_data_exporter.cc b/src/c++/perf_analyzer/profile_data_exporter.cc index e30807a62..ea79d6856 100644 --- a/src/c++/perf_analyzer/profile_data_exporter.cc +++ b/src/c++/perf_analyzer/profile_data_exporter.cc @@ -47,15 +47,17 @@ ProfileDataExporter::Create(std::shared_ptr* exporter) void ProfileDataExporter::Export( const std::vector& raw_experiments, std::string& raw_version, - std::string& file_path) + std::string& file_path, cb::BackendKind& service_kind, + std::string& endpoint) { - ConvertToJson(raw_experiments, raw_version); + ConvertToJson(raw_experiments, raw_version, service_kind, endpoint); OutputToFile(file_path); } void ProfileDataExporter::ConvertToJson( - const std::vector& raw_experiments, std::string& raw_version) + const std::vector& raw_experiments, std::string& raw_version, + cb::BackendKind& service_kind, std::string& endpoint) { ClearDocument(); rapidjson::Value experiments(rapidjson::kArrayType); @@ -75,6 +77,8 @@ ProfileDataExporter::ConvertToJson( document_.AddMember("experiments", experiments, document_.GetAllocator()); AddVersion(raw_version); + AddServiceKind(service_kind); + AddEndpoint(endpoint); } void @@ -245,6 +249,39 @@ ProfileDataExporter::AddVersion(std::string& raw_version) document_.AddMember("version", version, document_.GetAllocator()); } +void +ProfileDataExporter::AddServiceKind(cb::BackendKind& kind) +{ + std::string raw_service_kind{""}; + if (kind == cb::BackendKind::TRITON) { + raw_service_kind = "triton"; + } else if (kind == cb::BackendKind::TENSORFLOW_SERVING) { + raw_service_kind = "tfserving"; + } else if (kind == cb::BackendKind::TORCHSERVE) { + raw_service_kind = "torchserve"; + } else if (kind == cb::BackendKind::TRITON_C_API) { + raw_service_kind = "triton_c_api"; + } else if (kind == cb::BackendKind::OPENAI) { + raw_service_kind = "openai"; + } else { + std::cerr << "Unknown service kind detected. The 'service_kind' will not " + "be specified." + << std::endl; + } + + rapidjson::Value service_kind; + service_kind.SetString(raw_service_kind.c_str(), document_.GetAllocator()); + document_.AddMember("service_kind", service_kind, document_.GetAllocator()); +} + +void +ProfileDataExporter::AddEndpoint(std::string& raw_endpoint) +{ + rapidjson::Value endpoint; + endpoint = rapidjson::StringRef(raw_endpoint.c_str()); + document_.AddMember("endpoint", endpoint, document_.GetAllocator()); +} + void ProfileDataExporter::OutputToFile(std::string& file_path) { diff --git a/src/c++/perf_analyzer/profile_data_exporter.h b/src/c++/perf_analyzer/profile_data_exporter.h index 9465f2b9d..820148d7a 100644 --- a/src/c++/perf_analyzer/profile_data_exporter.h +++ b/src/c++/perf_analyzer/profile_data_exporter.h @@ -49,9 +49,12 @@ class ProfileDataExporter { /// @param raw_version String containing the version number for the json /// output /// @param file_path File path to export profile data to. + /// @param service_kind Service that Perf Analyzer generates load for. + /// @param endpoint Endpoint to send the requests. 
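For context on the exporter changes above, here is a small stand-alone sketch, not part of the patch, of how the two new top-level members end up in the rapidjson document. It mirrors AddServiceKind() and AddEndpoint(), except that the endpoint string is copied with SetString() here for simplicity, whereas the patch stores it via rapidjson::StringRef() without copying. In the exported file these members sit next to the existing "experiments" and "version" entries.

#include <iostream>
#include <string>

#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

int main()
{
  rapidjson::Document document;
  document.SetObject();

  // Mirrors AddServiceKind(): the backend kind is first mapped to a string
  // ("triton", "tfserving", "torchserve", "triton_c_api", or "openai").
  std::string raw_service_kind{"triton"};
  rapidjson::Value service_kind;
  service_kind.SetString(raw_service_kind.c_str(), document.GetAllocator());
  document.AddMember("service_kind", service_kind, document.GetAllocator());

  // Mirrors AddEndpoint(); the patch uses rapidjson::StringRef() instead of
  // copying the string.
  std::string raw_endpoint{"v1/chat/completions"};
  rapidjson::Value endpoint;
  endpoint.SetString(raw_endpoint.c_str(), document.GetAllocator());
  document.AddMember("endpoint", endpoint, document.GetAllocator());

  rapidjson::StringBuffer buffer;
  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
  document.Accept(writer);
  std::cout << buffer.GetString() << std::endl;
  // {"service_kind":"triton","endpoint":"v1/chat/completions"}
}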
void Export( const std::vector& raw_experiments, std::string& raw_version, - std::string& file_path); + std::string& file_path, cb::BackendKind& service_kind, + std::string& endpoint); private: ProfileDataExporter() = default; @@ -60,8 +63,11 @@ class ProfileDataExporter { /// analyzer /// @param raw_version String containing the version number for the json /// output + /// @param service_kind Service that Perf Analyzer generates load for. + /// @param endpoint Endpoint to send the requests. virtual void ConvertToJson( - const std::vector& raw_experiments, std::string& raw_version); + const std::vector& raw_experiments, std::string& raw_version, + cb::BackendKind& service_kind, std::string& endpoint); virtual void OutputToFile(std::string& file_path); virtual void AddExperiment( rapidjson::Value& entry, rapidjson::Value& experiment, @@ -83,6 +89,8 @@ class ProfileDataExporter { rapidjson::Value& entry, rapidjson::Value& window_boundaries, const Experiment& raw_experiment); void AddVersion(std::string& raw_version); + void AddServiceKind(cb::BackendKind& service_kind); + void AddEndpoint(std::string& endpoint); void ClearDocument(); rapidjson::Document document_{}; diff --git a/src/c++/perf_analyzer/request_rate_manager.cc b/src/c++/perf_analyzer/request_rate_manager.cc index a79c52ff4..be12282ab 100644 --- a/src/c++/perf_analyzer/request_rate_manager.cc +++ b/src/c++/perf_analyzer/request_rate_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -89,10 +89,11 @@ RequestRateManager::InitManagerFinalize() } cb::Error -RequestRateManager::ChangeRequestRate(const double request_rate) +RequestRateManager::ChangeRequestRate( + const double request_rate, const size_t request_count) { PauseWorkers(); - ConfigureThreads(); + ConfigureThreads(request_count); // Can safely update the schedule GenerateSchedule(request_rate); ResumeWorkers(); @@ -229,15 +230,14 @@ RequestRateManager::PauseWorkers() } void -RequestRateManager::ConfigureThreads() +RequestRateManager::ConfigureThreads(const size_t request_count) { if (threads_.empty()) { size_t num_of_threads = DetermineNumThreads(); while (workers_.size() < num_of_threads) { // Launch new thread for inferencing threads_stat_.emplace_back(new ThreadStat()); - threads_config_.emplace_back( - new RequestRateWorker::ThreadConfig(workers_.size())); + threads_config_.emplace_back(new ThreadConfig(workers_.size())); workers_.push_back( MakeWorker(threads_stat_.back(), threads_config_.back())); @@ -247,11 +247,20 @@ RequestRateManager::ConfigureThreads() size_t avg_num_seqs = num_of_sequences_ / workers_.size(); size_t num_seqs_add_one = num_of_sequences_ % workers_.size(); size_t seq_offset = 0; + + size_t avg_req_count = request_count / workers_.size(); + size_t req_count_add_one = request_count % workers_.size(); + + for (size_t i = 0; i < workers_.size(); i++) { size_t num_of_seq = avg_num_seqs + (i < num_seqs_add_one ? 1 : 0); threads_config_[i]->num_sequences_ = num_of_seq; threads_config_[i]->seq_stat_index_offset_ = seq_offset; seq_offset += num_of_seq; + + size_t thread_num_reqs = avg_req_count + (i < req_count_add_one ? 
1 : 0); + threads_config_[i]->num_requests_ = thread_num_reqs; + threads_.emplace_back(&IWorker::Infer, workers_[i]); } } @@ -271,7 +280,7 @@ RequestRateManager::ResumeWorkers() std::shared_ptr RequestRateManager::MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) + std::shared_ptr thread_config) { size_t id = workers_.size(); size_t num_of_threads = DetermineNumThreads(); diff --git a/src/c++/perf_analyzer/request_rate_manager.h b/src/c++/perf_analyzer/request_rate_manager.h index deb8ed953..8c9131bb4 100644 --- a/src/c++/perf_analyzer/request_rate_manager.h +++ b/src/c++/perf_analyzer/request_rate_manager.h @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -99,10 +99,13 @@ class RequestRateManager : public LoadManager { request_parameters); /// Adjusts the rate of issuing requests to be the same as 'request_rate' - /// \param request_rate The rate at which requests must be issued to the - /// server. + /// \param target_request_rate The rate at which requests must be issued to + /// the server. + /// \param request_count The number of requests to generate when profiling. If + /// 0, then there is no limit, and it will generate until told to stop. /// \return cb::Error object indicating success or failure. - cb::Error ChangeRequestRate(const double target_request_rate); + cb::Error ChangeRequestRate( + const double target_request_rate, const size_t request_count = 0); protected: RequestRateManager( @@ -138,19 +141,18 @@ class RequestRateManager : public LoadManager { // Pauses the worker threads void PauseWorkers(); - void ConfigureThreads(); + void ConfigureThreads(const size_t request_count = 0); // Resets the counters and resumes the worker threads void ResumeWorkers(); // Makes a new worker virtual std::shared_ptr MakeWorker( - std::shared_ptr, - std::shared_ptr); + std::shared_ptr, std::shared_ptr); size_t DetermineNumThreads(); - std::vector> threads_config_; + std::vector> threads_config_; std::shared_ptr gen_duration_; Distribution request_distribution_; diff --git a/src/c++/perf_analyzer/request_rate_worker.h b/src/c++/perf_analyzer/request_rate_worker.h index c7f75c75a..e6d1804c6 100644 --- a/src/c++/perf_analyzer/request_rate_worker.h +++ b/src/c++/perf_analyzer/request_rate_worker.h @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
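A note on the request distribution shown in ConfigureThreads() above (and exercised through the concurrency manager's ReconfigThreads() in the tests later in this diff): the total --request-count is split so that every worker gets the integer average, and the first request_count % num_workers workers take one extra request. A stand-alone sketch, not part of the patch; SplitRequestCount is a hypothetical helper used only for illustration:

#include <cstddef>
#include <iostream>
#include <vector>

// Each worker receives the average, and the remainder is spread one request
// at a time over the leading workers.
std::vector<size_t> SplitRequestCount(size_t request_count, size_t num_workers)
{
  const size_t avg_req_count = request_count / num_workers;
  const size_t req_count_add_one = request_count % num_workers;
  std::vector<size_t> per_worker;
  for (size_t i = 0; i < num_workers; i++) {
    per_worker.push_back(avg_req_count + (i < req_count_add_one ? 1 : 0));
  }
  return per_worker;
}

int main()
{
  for (size_t n : SplitRequestCount(13, 4)) {
    std::cout << n << " ";  // prints "4 3 3 3"
  }
  std::cout << std::endl;
}

The 13-requests-over-4-workers case reproduces the {4, 3, 3, 3} split asserted by the "not divisible" subcase in the request rate manager test changes later in this diff.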
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -31,6 +31,7 @@ #include "load_worker.h" #include "model_parser.h" #include "sequence_manager.h" +#include "thread_config.h" namespace triton { namespace perfanalyzer { @@ -50,21 +51,6 @@ class TestCustomLoadManager; /// class RequestRateWorker : public LoadWorker, public IScheduler { public: - struct ThreadConfig { - ThreadConfig(uint32_t index) - : id_(index), seq_stat_index_offset_(0), is_paused_(false), - num_sequences_(1) - { - } - - uint32_t id_; - - // The starting sequence stat index for this worker - size_t seq_stat_index_offset_; - uint32_t num_sequences_; - bool is_paused_; - }; - RequestRateWorker( uint32_t id, std::shared_ptr thread_stat, std::shared_ptr thread_config, @@ -80,11 +66,12 @@ class RequestRateWorker : public LoadWorker, public IScheduler { const std::shared_ptr& infer_data_manager, std::shared_ptr sequence_manager) : LoadWorker( - id, thread_stat, parser, data_loader, factory, on_sequence_model, - async, streaming, batch_size, using_json_data, wake_signal, - wake_mutex, execute, infer_data_manager, sequence_manager), - thread_config_(thread_config), num_threads_(num_threads), - start_time_(start_time), serial_sequences_(serial_sequences) + id, thread_stat, thread_config, parser, data_loader, factory, + on_sequence_model, async, streaming, batch_size, using_json_data, + wake_signal, wake_mutex, execute, infer_data_manager, + sequence_manager), + num_threads_(num_threads), start_time_(start_time), + serial_sequences_(serial_sequences) { } @@ -101,8 +88,6 @@ class RequestRateWorker : public LoadWorker, public IScheduler { const bool serial_sequences_; std::chrono::steady_clock::time_point& start_time_; - std::shared_ptr thread_config_; - void CreateCtxIdTracker(); std::chrono::nanoseconds GetNextTimestamp(); diff --git a/src/c++/perf_analyzer/test_command_line_parser.cc b/src/c++/perf_analyzer/test_command_line_parser.cc index 2527d2b1b..765def112 100644 --- a/src/c++/perf_analyzer/test_command_line_parser.cc +++ b/src/c++/perf_analyzer/test_command_line_parser.cc @@ -262,6 +262,7 @@ TEST_CASE("Testing PerfAnalyzerParameters") CHECK(params->max_threads_specified == false); CHECK(params->sequence_length == 20); CHECK(params->percentile == -1); + CHECK(params->request_count == 0); CHECK(params->user_data.size() == 0); CHECK_STRING("endpoint", params->endpoint, ""); CHECK(params->input_shapes.size() == 0); @@ -1469,6 +1470,150 @@ TEST_CASE("Testing Command Line Parser") } } + SUBCASE("Option : --request-count") + { + SUBCASE("valid value") + { + int argc = 5; + char* argv[argc] = {app_name, "-m", model_name, "--request-count", "500"}; + + REQUIRE_NOTHROW(act = parser.Parse(argc, argv)); + CHECK(!parser.UsageCalled()); + + exp->request_count = 500; + exp->measurement_mode = MeasurementMode::COUNT_WINDOWS; + exp->measurement_request_count = 500; + } + SUBCASE("negative value") + { + int argc = 5; + char* argv[argc] = {app_name, "-m", model_name, "--request-count", "-2"}; + + expected_msg = + CreateUsageMessage("--request-count", "The value must be > 0."); + CHECK_THROWS_WITH_AS( + act = parser.Parse(argc, argv), expected_msg.c_str(), + PerfAnalyzerException); + check_params = false; + } + SUBCASE("less than request rate") + { + int argc = 7; + char* argv[argc] = {app_name, "-m", + model_name, "--request-count", + "2", "--request-rate-range", + "5"}; + + expected_msg = "request-count can not be less than request-rate"; + 
CHECK_THROWS_WITH_AS( + act = parser.Parse(argc, argv), expected_msg.c_str(), + PerfAnalyzerException); + check_params = false; + } + SUBCASE("less than concurrency") + { + int argc = 7; + char* argv[argc] = {app_name, "-m", + model_name, "--request-count", + "2", "--concurrency-range", + "5"}; + + expected_msg = "request-count can not be less than concurrency"; + CHECK_THROWS_WITH_AS( + act = parser.Parse(argc, argv), expected_msg.c_str(), + PerfAnalyzerException); + check_params = false; + } + SUBCASE("multiple request rate") + { + int argc = 7; + char* argv[argc] = {app_name, "-m", + model_name, "--request-count", + "20", "--request-rate-range", + "5:6:1"}; + + expected_msg = + "request-count not supported with multiple request-rate values in " + "one run"; + CHECK_THROWS_WITH_AS( + act = parser.Parse(argc, argv), expected_msg.c_str(), + PerfAnalyzerException); + check_params = false; + } + SUBCASE("multiple concurrency") + { + int argc = 7; + char* argv[argc] = {app_name, "-m", + model_name, "--request-count", + "20", "--concurrency-range", + "5:6:1"}; + + expected_msg = + "request-count not supported with multiple concurrency values in " + "one run"; + CHECK_THROWS_WITH_AS( + act = parser.Parse(argc, argv), expected_msg.c_str(), + PerfAnalyzerException); + check_params = false; + } + + SUBCASE("mode and count are overwritten with non-zero request-count") + { + int argc = 9; + char* argv[argc] = { + app_name, + "-m", + model_name, + "--request-count", + "2000", + "--measurement-mode", + "time_windows", + "measurement-request-count", + "30"}; + + REQUIRE_NOTHROW(act = parser.Parse(argc, argv)); + CHECK(!parser.UsageCalled()); + + exp->request_count = 2000; + exp->measurement_mode = MeasurementMode::COUNT_WINDOWS; + exp->measurement_request_count = 2000; + } + SUBCASE("zero value (no override to measurement mode)") + { + int argc = 7; + char* argv[argc] = {app_name, "-m", model_name, + "--request-count", "0", "--measurement-mode", + "time_windows"}; + + REQUIRE_NOTHROW(act = parser.Parse(argc, argv)); + CHECK(!parser.UsageCalled()); + + exp->request_count = 0; + exp->measurement_mode = MeasurementMode::TIME_WINDOWS; + } + SUBCASE("zero value (no override to measurement request count)") + { + int argc = 9; + char* argv[argc] = { + app_name, + "-m", + model_name, + "--request-count", + "0", + "--measurement-mode", + "count_windows", + "--measurement-request-count", + "50"}; + + REQUIRE_NOTHROW(act = parser.Parse(argc, argv)); + CHECK(!parser.UsageCalled()); + + exp->request_count = 0; + exp->measurement_mode = MeasurementMode::COUNT_WINDOWS; + exp->measurement_request_count = 50; + } + } + SUBCASE("Option : --collect-metrics") { SUBCASE("with --service-kind != triton") diff --git a/src/c++/perf_analyzer/test_concurrency_manager.cc b/src/c++/perf_analyzer/test_concurrency_manager.cc index 58d3a3031..1941a018e 100644 --- a/src/c++/perf_analyzer/test_concurrency_manager.cc +++ b/src/c++/perf_analyzer/test_concurrency_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -60,7 +60,7 @@ class TestConcurrencyManager : public TestLoadManagerBase, std::shared_ptr MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) override + std::shared_ptr thread_config) override { size_t id = workers_.size(); @@ -80,10 +80,10 @@ class TestConcurrencyManager : public TestLoadManagerBase, void TestReconfigThreads( - const size_t concurrent_request_count, - std::vector& expected_configs) + const size_t concurrent_request_count, const size_t num_requests, + std::vector& expected_configs) { - ConcurrencyManager::ReconfigThreads(concurrent_request_count); + ConcurrencyManager::ReconfigThreads(concurrent_request_count, num_requests); auto expected_size = expected_configs.size(); @@ -99,6 +99,9 @@ class TestConcurrencyManager : public TestLoadManagerBase, CHECK( threads_config_[i]->seq_stat_index_offset_ == expected_configs[i].seq_stat_index_offset_); + CHECK( + threads_config_[i]->num_requests_ == + expected_configs[i].num_requests_); } } @@ -461,8 +464,8 @@ TEST_CASE("concurrency_free_ctx_ids") tcm.stats_->SetDelays({50, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}); std::shared_ptr thread_stat{std::make_shared()}; - std::shared_ptr thread_config{ - std::make_shared(0)}; + std::shared_ptr thread_config{ + std::make_shared(0)}; thread_config->concurrency_ = 4; std::shared_ptr worker{tcm.MakeWorker(thread_stat, thread_config)}; @@ -565,8 +568,8 @@ TEST_CASE("Concurrency - shared memory infer input calls") mip.mock_model_parser_, tcm.factory_, mip.mock_data_loader_); std::shared_ptr thread_stat{std::make_shared()}; - std::shared_ptr thread_config{ - std::make_shared(0)}; + std::shared_ptr thread_config{ + std::make_shared(0)}; thread_config->concurrency_ = 1; tcm.parser_ = mip.mock_model_parser_; @@ -867,50 +870,71 @@ TEST_CASE( TEST_CASE( "reconfigure_threads" * doctest::description( - "This test confirms the side-effects of ReconfigureThreads(). Namely, " + "This test confirms the side-effects of ReconfigThreads(). 
Namely, " "that the correct number of threads are created and that they are " "configured properly")) { PerfAnalyzerParameters params{}; - std::vector expected_config_values; + std::vector expected_config_values; std::vector expected_concurrencies; std::vector expected_seq_stat_index_offsets; + std::vector expected_num_requests; + size_t target_concurrency = 0; + size_t target_num_requests = 0; SUBCASE("normal") { params.max_threads = 10; target_concurrency = 5; + target_num_requests = 15; expected_concurrencies = {1, 1, 1, 1, 1}; expected_seq_stat_index_offsets = {0, 1, 2, 3, 4}; + expected_num_requests = {3, 3, 3, 3, 3}; } SUBCASE("thread_limited") { params.max_threads = 5; target_concurrency = 10; + target_num_requests = 20; expected_concurrencies = {2, 2, 2, 2, 2}; expected_seq_stat_index_offsets = {0, 2, 4, 6, 8}; + expected_num_requests = {4, 4, 4, 4, 4}; } SUBCASE("unbalanced") { params.max_threads = 6; target_concurrency = 14; + target_num_requests = 15; expected_concurrencies = {3, 3, 2, 2, 2, 2}; expected_seq_stat_index_offsets = {0, 3, 6, 8, 10, 12}; + expected_num_requests = {3, 3, 3, 2, 2, 2}; + } + SUBCASE("no requests specified") + { + params.max_threads = 2; + target_concurrency = 14; + target_num_requests = 0; + + expected_concurrencies = {7, 7}; + expected_seq_stat_index_offsets = {0, 7}; + expected_num_requests = {0, 0}; } for (auto i = 0; i < expected_concurrencies.size(); i++) { - ConcurrencyWorker::ThreadConfig tc(i); + ThreadConfig tc(i); tc.concurrency_ = expected_concurrencies[i]; tc.seq_stat_index_offset_ = expected_seq_stat_index_offsets[i]; + tc.num_requests_ = expected_num_requests[i]; expected_config_values.push_back(tc); } TestConcurrencyManager tcm(params); - tcm.TestReconfigThreads(target_concurrency, expected_config_values); + tcm.TestReconfigThreads( + target_concurrency, target_num_requests, expected_config_values); } diff --git a/src/c++/perf_analyzer/test_custom_load_manager.cc b/src/c++/perf_analyzer/test_custom_load_manager.cc index 0cb6c4c5c..ced79af7d 100644 --- a/src/c++/perf_analyzer/test_custom_load_manager.cc +++ b/src/c++/perf_analyzer/test_custom_load_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -70,7 +70,7 @@ class TestCustomLoadManager : public TestLoadManagerBase, std::shared_ptr MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) override + std::shared_ptr thread_config) override { size_t id = workers_.size(); auto worker = std::make_shared( diff --git a/src/c++/perf_analyzer/test_inference_profiler.cc b/src/c++/perf_analyzer/test_inference_profiler.cc index 683219f15..8ff39605b 100644 --- a/src/c++/perf_analyzer/test_inference_profiler.cc +++ b/src/c++/perf_analyzer/test_inference_profiler.cc @@ -160,8 +160,15 @@ class TestInferenceProfiler : public InferenceProfiler { return InferenceProfiler::DetermineStatsModelVersion( model_identifier, start_stats, end_stats, model_version); } + + cb::Error SetTopLevelResponseCaching(bool enable_top_level_response_caching) + { + return InferenceProfiler::SetTopLevelResponseCaching( + enable_top_level_response_caching); + } }; + TEST_CASE("testing the ValidLatencyMeasurement function") { size_t valid_sequence_count{}; @@ -850,6 +857,25 @@ TEST_CASE("determine_stats_model_version: testing DetermineStatsModelVersion()") expect_exception = true; } + SUBCASE("One entry - version -1 - valid and in start") + { + model_identifier = {"ModelA", "-1"}; + start_stats_map.insert({{"ModelA", "3"}, old_stats}); + end_stats_map.insert({{"ModelA", "3"}, new_stats}); + cb::Error status = tip.SetTopLevelResponseCaching(true); + CHECK(status.IsOk()); + expected_model_version = -1; + } + + SUBCASE("One entry - version -1 - not valid") + { + model_identifier = {"ModelA", "-1"}; + end_stats_map.insert({{"ModelA", "3"}, old_stats}); + cb::Error status = tip.SetTopLevelResponseCaching(false); + CHECK(status.IsOk()); + expected_model_version = -1; + expect_exception = true; + } std::stringstream captured_cerr; std::streambuf* old = std::cerr.rdbuf(captured_cerr.rdbuf()); diff --git a/src/c++/perf_analyzer/test_profile_data_exporter.cc b/src/c++/perf_analyzer/test_profile_data_exporter.cc index 3cef51afb..ffd958c5c 100644 --- a/src/c++/perf_analyzer/test_profile_data_exporter.cc +++ b/src/c++/perf_analyzer/test_profile_data_exporter.cc @@ -102,8 +102,10 @@ TEST_CASE("profile_data_exporter: ConvertToJson") std::vector experiments{experiment}; std::string version{"1.2.3"}; + cb::BackendKind service_kind = cb::BackendKind::TRITON; + std::string endpoint{""}; - exporter.ConvertToJson(experiments, version); + exporter.ConvertToJson(experiments, version, service_kind, endpoint); std::string json{R"( { @@ -125,7 +127,9 @@ TEST_CASE("profile_data_exporter: ConvertToJson") "window_boundaries" : [ 1, 5, 6 ] } ], - "version" : "1.2.3" + "version" : "1.2.3", + "service_kind": "triton", + "endpoint": "" } )"}; @@ -244,4 +248,80 @@ TEST_CASE("profile_data_exporter: OutputToFile") } } +TEST_CASE("profile_data_exporter: AddServiceKind") +{ + MockProfileDataExporter exporter{}; + exporter.ClearDocument(); + + cb::BackendKind service_kind; + std::string json{""}; + + SUBCASE("Backend kind: TRITON") + { + service_kind = cb::BackendKind::TRITON; + json = R"({ "service_kind": "triton" })"; + } + + SUBCASE("Backend kind: TENSORFLOW_SERVING") + { + service_kind = cb::BackendKind::TENSORFLOW_SERVING; + json = R"({ "service_kind": "tfserving" })"; + } + + SUBCASE("Backend kind: TORCHSERVE") + { + service_kind = cb::BackendKind::TORCHSERVE; + json = R"({ "service_kind": "torchserve" })"; + } + + SUBCASE("Backend kind: 
TRITON_C_API") + { + service_kind = cb::BackendKind::TRITON_C_API; + json = R"({ "service_kind": "triton_c_api" })"; + } + + SUBCASE("Backend kind: OPENAI") + { + service_kind = cb::BackendKind::OPENAI; + json = R"({ "service_kind": "openai" })"; + } + + exporter.AddServiceKind(service_kind); + rapidjson::Document expected_document; + expected_document.Parse(json.c_str()); + + const rapidjson::Value& expected_kind{expected_document["service_kind"]}; + const rapidjson::Value& actual_kind{exporter.document_["service_kind"]}; + CHECK(actual_kind == expected_kind); +} + +TEST_CASE("profile_data_exporter: AddEndpoint") +{ + MockProfileDataExporter exporter{}; + exporter.ClearDocument(); + + std::string endpoint{""}; + std::string json{""}; + + SUBCASE("Endpoint: OpenAI Chat Completions") + { + endpoint = "v1/chat/completions"; + json = R"({ "endpoint": "v1/chat/completions" })"; + } + + SUBCASE("Endpoint: OpenAI Completions") + { + endpoint = "v1/completions"; + json = R"({ "endpoint": "v1/completions" })"; + } + + exporter.AddEndpoint(endpoint); + rapidjson::Document expected_document; + expected_document.Parse(json.c_str()); + + const rapidjson::Value& expected_endpoint{expected_document["endpoint"]}; + const rapidjson::Value& actual_endpoint{exporter.document_["endpoint"]}; + CHECK(actual_endpoint == expected_endpoint); +} + }} // namespace triton::perfanalyzer diff --git a/src/c++/perf_analyzer/test_request_rate_manager.cc b/src/c++/perf_analyzer/test_request_rate_manager.cc index 008424b72..48e428946 100644 --- a/src/c++/perf_analyzer/test_request_rate_manager.cc +++ b/src/c++/perf_analyzer/test_request_rate_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -68,7 +68,7 @@ class TestRequestRateManager : public TestLoadManagerBase, std::shared_ptr MakeWorker( std::shared_ptr thread_stat, - std::shared_ptr thread_config) override + std::shared_ptr thread_config) override { size_t id = workers_.size(); auto worker = std::make_shared( @@ -86,9 +86,9 @@ class TestRequestRateManager : public TestLoadManagerBase, } void TestConfigureThreads( - std::vector& expected_configs) + std::vector& expected_configs, size_t request_count) { - RequestRateManager::ConfigureThreads(); + RequestRateManager::ConfigureThreads(request_count); auto expected_size = expected_configs.size(); @@ -105,6 +105,9 @@ class TestRequestRateManager : public TestLoadManagerBase, CHECK( threads_config_[i]->seq_stat_index_offset_ == expected_configs[i].seq_stat_index_offset_); + CHECK( + threads_config_[i]->num_requests_ == + expected_configs[i].num_requests_); } } @@ -401,8 +404,8 @@ class TestRequestRateManager : public TestLoadManagerBase, cb::Error CustomDataTestSendRequests(size_t num_requests) { std::shared_ptr thread_stat{std::make_shared()}; - std::shared_ptr thread_config{ - std::make_shared(0)}; + std::shared_ptr thread_config{ + std::make_shared(0)}; std::shared_ptr worker{MakeWorker(thread_stat, thread_config)}; auto mock_worker = std::dynamic_pointer_cast(worker); @@ -948,8 +951,8 @@ TEST_CASE("request_rate_streaming: test that streaming-specific logic works") schedule->duration = nanoseconds{1}; std::shared_ptr thread_stat{std::make_shared()}; - std::shared_ptr thread_config{ - std::make_shared(0)}; + std::shared_ptr thread_config{ + std::make_shared(0)}; TestRequestRateManager trrm(params, is_sequence, is_decoupled); trrm.InitManager( @@ -1698,8 +1701,8 @@ TEST_CASE("Request rate - Shared memory infer input calls") mip.mock_model_parser_, trrm.factory_, mip.mock_data_loader_); std::shared_ptr thread_stat{std::make_shared()}; - std::shared_ptr thread_config{ - std::make_shared(0)}; + std::shared_ptr thread_config{ + std::make_shared(0)}; trrm.parser_ = mip.mock_model_parser_; trrm.data_loader_ = mip.mock_data_loader_; @@ -1947,59 +1950,71 @@ TEST_CASE( TEST_CASE("request rate manager - Configure threads") { PerfAnalyzerParameters params{}; - std::vector expected_config_values; + std::vector expected_config_values; std::vector expected_number_of_sequences_owned_by_thread; std::vector expected_seq_stat_index_offsets; + std::vector expected_num_requests; bool is_sequence_model = true; bool is_decoupled_model = false; bool use_mock_infer = true; + size_t target_num_requests = 0; SUBCASE("normal") { params.max_threads = 4; params.num_of_sequences = 4; + target_num_requests = 0; expected_number_of_sequences_owned_by_thread = {1, 1, 1, 1}; expected_seq_stat_index_offsets = {0, 1, 2, 3}; + expected_num_requests = {0, 0, 0, 0}; } SUBCASE("max_threads > num_seqs") { params.max_threads = 10; params.num_of_sequences = 4; + target_num_requests = 8; expected_number_of_sequences_owned_by_thread = {1, 1, 1, 1}; expected_seq_stat_index_offsets = {0, 1, 2, 3}; + expected_num_requests = {2, 2, 2, 2}; } SUBCASE("num_seqs > max_threads") { params.max_threads = 4; params.num_of_sequences = 10; + target_num_requests = 20; expected_number_of_sequences_owned_by_thread = {3, 3, 2, 2}; expected_seq_stat_index_offsets = {0, 3, 6, 8}; + expected_num_requests = {5, 5, 5, 5}; } SUBCASE("not divisible") { params.max_threads = 4; params.num_of_sequences = 7; 
+    target_num_requests = 13;

    expected_number_of_sequences_owned_by_thread = {2, 2, 2, 1};
    expected_seq_stat_index_offsets = {0, 2, 4, 6};
+    expected_num_requests = {4, 3, 3, 3};
  }

  for (auto i = 0; i < expected_number_of_sequences_owned_by_thread.size(); i++) {
-    RequestRateWorker::ThreadConfig tc(i);
+    ThreadConfig tc(i);
    tc.num_sequences_ = expected_number_of_sequences_owned_by_thread[i];
    tc.seq_stat_index_offset_ = expected_seq_stat_index_offsets[i];
+    tc.num_requests_ = expected_num_requests[i];
+
    expected_config_values.push_back(tc);
  }

  TestRequestRateManager trrm(
      params, is_sequence_model, is_decoupled_model, use_mock_infer);

-  trrm.TestConfigureThreads(expected_config_values);
+  trrm.TestConfigureThreads(expected_config_values, target_num_requests);
}

TEST_CASE("request rate manager - Calculate thread ids")

diff --git a/src/c++/perf_analyzer/thread_config.h b/src/c++/perf_analyzer/thread_config.h
new file mode 100644
index 000000000..4c4845a6e
--- /dev/null
+++ b/src/c++/perf_analyzer/thread_config.h
@@ -0,0 +1,58 @@
+// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright
+// notice, this list of conditions and the following disclaimer in the
+// documentation and/or other materials provided with the distribution.
+// * Neither the name of NVIDIA CORPORATION nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+namespace triton { namespace perfanalyzer {
+
+// Holds the configuration for a worker thread
+struct ThreadConfig {
+  ThreadConfig(size_t thread_id) : thread_id_(thread_id) {}
+
+  // ID of corresponding worker thread
+  size_t thread_id_{0};
+
+  // The concurrency level that the worker should produce
+  // TPA-69: This is only used in concurrency mode and shouldn't be visible in
+  // other modes
+  size_t concurrency_{0};
+
+  // The number of sequences owned by this worker
+  // TPA-69: This is only used in request-rate mode and shouldn't be visible in
+  // other modes
+  uint32_t num_sequences_{1};
+
+  // How many requests to generate before stopping. If 0, generate indefinitely
+  size_t num_requests_{0};
+
+  // The starting sequence stat index for this worker
+  size_t seq_stat_index_offset_{0};
+
+  // Whether or not the thread is issuing new inference requests
+  bool is_paused_{false};
+};
+
+
+}} // namespace triton::perfanalyzer
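With the per-manager nested ThreadConfig structs removed, every load mode now shares the single struct above and fills in only the fields it needs, as the periodic-concurrency and request-rate changes earlier in this diff do after constructing it. A brief stand-alone usage sketch, not part of the patch:

#include <cstddef>
#include <cstdint>
#include <memory>

#include "thread_config.h"

int main()
{
  using triton::perfanalyzer::ThreadConfig;

  // A manager constructs the config with the worker id, then assigns the
  // fields relevant to its load mode.
  auto config = std::make_shared<ThreadConfig>(0);
  config->concurrency_ = 1;            // concurrency-based modes
  config->num_sequences_ = 2;          // request-rate sequence mode
  config->seq_stat_index_offset_ = 0;  // first sequence stat owned by worker 0
  config->num_requests_ = 4;           // stop after 4 requests; 0 means no limit
  return config->is_paused_ ? 1 : 0;
}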