Skip to content

Commit

Permalink
Add helper functions to pull metrics in HTTPAPIServer to pull metrics…
Browse files Browse the repository at this point in the history
… for use in HandleGenerate to add kv_utilization and max_token_capacity to the inference request response header.
  • Loading branch information
BenjaminBraunDev committed Dec 11, 2024
1 parent e0f0734 commit 74492a8
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 1 deletion.
150 changes: 150 additions & 0 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3225,6 +3225,42 @@ HTTPAPIServer::HandleGenerate(
req, RestrictedCategory::INFERENCE, restricted_apis_);

AddContentTypeHeader(req, "application/json");
#ifdef TRITON_ENABLE_METRICS
// logic to add kv_cache metrics to response header
// Get the metrics in Prometheus format
if (std::getenv("ORCA_HEADER_METRIC_TYPE") != nullptr) {
const std::string orca_type = std::getenv("ORCA_HEADER_METRIC_TYPE");
TRITONSERVER_Metrics* metrics = nullptr;
TRITONSERVER_Error* err =
TRITONSERVER_ServerMetrics(server_.get(), &metrics);
if (err == nullptr) {
const char* base;
size_t byte_size;
err = TRITONSERVER_MetricsFormatted(
metrics, TRITONSERVER_METRIC_PROMETHEUS, &base, &byte_size);
if (err == nullptr) {
std::string kv_utilization(base, byte_size);
// Extract the KV utilization metrics from the Prometheus formatted
// string.
std::string extracted_kv_metrics =
ExtractKVMetrics(kv_utilization, orca_type);
if (!extracted_kv_metrics.empty()) {
evhtp_headers_add_header(
req->headers_out,
evhtp_header_new(
"endpoint-load-metrics", extracted_kv_metrics.c_str(), 1, 1));
}
}
}
TRITONSERVER_MetricsDelete(metrics);
// Handle potential errors
if (err != nullptr) {
LOG_ERROR << "Failed to get KV metrics: "
<< TRITONSERVER_ErrorMessage(err);
TRITONSERVER_ErrorDelete(err);
}
}
#endif // TRITON_ENABLE_METRICS
if (req->method != htp_method_POST) {
RETURN_AND_RESPOND_WITH_ERR(
req, EVHTP_RES_METHNALLOWED, "Method Not Allowed");
Expand Down Expand Up @@ -3381,6 +3417,120 @@ HTTPAPIServer::HandleGenerate(
request_release_payload.release();
}

#ifdef TRITON_ENABLE_METRICS
std::vector<HTTPAPIServer::PromMetric>
HTTPAPIServer::MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily)
{
std::vector<PromMetric> metrics;
// Construct the regex pattern using the provided metricFamily
std::string patternStr = metricFamily + R"((?:{(.*?)})?\s+(\d+(?:\.\d+)?))";
re2::RE2 pattern(patternStr);
re2::StringPiece inputPiece(input);

std::string labelString;
std::string metric_value;

while (re2::RE2::FindAndConsume(
&inputPiece, pattern, &labelString, &metric_value)) {
PromMetric metric;

// Extract labels if they exist
if (!labelString.empty()) {
re2::RE2 labelPattern(R"((\w+)=\"([^\"]+)\")");
re2::StringPiece labelPiece(labelString);
std::string key, value;
while (
re2::RE2::FindAndConsume(&labelPiece, labelPattern, &key, &value)) {
metric.labels[key] = value;
}
}

// Assign the value
metric.value = stod(metric_value);
metrics.push_back(metric);
}

return metrics;
}

std::string
HTTPAPIServer::ExtractKVMetrics(
const std::string& prometheus_metrics, const std::string& orca_type)
{
std::string metric_family = "nv_trt_llm_kv_cache_block_metrics";
std::vector<PromMetric> kv_cache_metrics =
MetricFamilyExtractor(prometheus_metrics, metric_family);

double tokens_per_block = -1;
double used_blocks = -1;
double max_blocks = -1;

for (const auto& metric : kv_cache_metrics) {
if (metric.labels.count("kv_cache_block_type") > 0) {
std::string type = metric.labels.at("kv_cache_block_type");
if (type == "tokens_per") {
tokens_per_block = metric.value;
} else if (type == "used") {
used_blocks = metric.value;
} else if (type == "max") {
max_blocks = metric.value;
}
}
}

// One or more of the kv metrics was not found or invalid.
if (tokens_per_block < 0 || used_blocks < 0 || max_blocks < 0) {
return "";
}

// Calculate derived metrics
double kv_cache_utilization = 0;
if (max_blocks > 0) {
kv_cache_utilization = used_blocks / max_blocks;
}
uint64_t max_token_capacity =
static_cast<uint64_t>(max_blocks * tokens_per_block);

// Logic to construct and format response header
std::string header_contents = "";
const std::string named_metrics_key = "named_metrics";
const std::string kv_util_key = "kv_cache_utilization";
const std::string max_token_key = "max_token_capacity";

if (orca_type == "json") {
// Format the metrics according to the ORCA protocol as JSON.
triton::common::TritonJson::Value orca_metrics(
triton::common::TritonJson::ValueType::OBJECT);
triton::common::TritonJson::Value named_metrics(
orca_metrics, triton::common::TritonJson::ValueType::OBJECT);

named_metrics.AddDouble(kv_util_key.c_str(), kv_cache_utilization);
named_metrics.AddUInt(max_token_key.c_str(), max_token_capacity);
orca_metrics.Add(named_metrics_key.c_str(), std::move(named_metrics));

triton::common::TritonJson::WriteBuffer buffer;
orca_metrics.Write(&buffer);
header_contents = std::string("JSON ") + buffer.Contents();

} else if (orca_type == "http") {
// Format the metrics according to the ORCA protocol as Native HTTP
// (comma separated list).
const std::string prefix = named_metrics_key + ".";

header_contents = "TEXT ";
header_contents += prefix + kv_util_key + "=" +
std::to_string(kv_cache_utilization) + ", ";
header_contents +=
prefix + max_token_key + "=" + std::to_string(max_token_capacity);
} else {
LOG_ERROR << "orca_type is set to an invalid type: " << orca_type;
}

return header_contents;
}
#endif // TRITON_ENABLE_METRICS

TRITONSERVER_Error*
HTTPAPIServer::ModelInputMetadata(
const std::string& model_name, const int64_t model_version,
Expand Down
20 changes: 19 additions & 1 deletion src/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,14 @@ class HTTPAPIServer : public HTTPServer {
evbuffer* buffer_ = nullptr;
};

#ifdef TRITON_ENABLE_METRICS
private:
// One sample parsed from Prometheus text output: its labels and its value.
struct PromMetric {
  // Label name -> label value for this sample; empty when it has no labels.
  std::unordered_map<std::string, std::string> labels;
  // Numeric sample value. Initialized so that a default-constructed
  // PromMetric never exposes an indeterminate double.
  double value = 0.0;
};
#endif // TRITON_ENABLE_METRICS

protected:
explicit HTTPAPIServer(
const std::shared_ptr<TRITONSERVER_Server>& server,
Expand Down Expand Up @@ -558,7 +566,17 @@ class HTTPAPIServer : public HTTPServer {
void HandleGenerate(
evhtp_request_t* req, const std::string& model_name,
const std::string& model_version_str, bool streaming);

#ifdef TRITON_ENABLE_METRICS
// Builds the KV-cache load report that is placed in the inference response
// header.
//
// 'prometheus_metrics' is the server metrics text in Prometheus format.
// 'orca_type' selects the output encoding: "json" produces a string starting
// with "JSON ", "http" produces a string starting with "TEXT ".
//
// The report contains kv_cache_utilization and max_token_capacity. An empty
// string is returned when the required KV-cache block metrics are not present
// in 'prometheus_metrics' or when 'orca_type' is not one of the values above.
std::string ExtractKVMetrics(
const std::string& prometheus_metrics, const std::string& orca_type);

// Parses the Prometheus-format text in 'input' and returns one PromMetric
// (a map of labels plus a numeric value) for every sample that matches
// 'metricFamily'. The returned vector is empty when no sample matches.
std::vector<PromMetric> MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily);
#endif // TRITON_ENABLE_METRICS
// 'meta_data_root' is the root JSON document for 'input_metadata'.
// In TritonJson, the Value objects are references to the root document.
// Therefore the document must stay valid.
Expand Down

0 comments on commit 74492a8

Please sign in to comment.