diff --git a/README.md b/README.md index 9582830c..ad40e400 100644 --- a/README.md +++ b/README.md @@ -338,6 +338,14 @@ from prometheus_client import start_wsgi_server start_wsgi_server(8000) ``` +By default, the WSGI application will respect `Accept-Encoding:gzip` headers used by Prometheus +and compress the response if such a header is present. This behaviour can be disabled by passing +`disable_compression=True` when creating the app, like this: + +```python +app = make_wsgi_app(disable_compression=True) +``` + #### ASGI To use Prometheus with [ASGI](http://asgi.readthedocs.org/en/latest/), there is @@ -351,6 +359,14 @@ app = make_asgi_app() Such an application can be useful when integrating Prometheus metrics with ASGI apps. +By default, the ASGI application will respect `Accept-Encoding:gzip` headers used by Prometheus +and compress the response if such a header is present. This behaviour can be disabled by passing +`disable_compression=True` when creating the app, like this: + +```python +app = make_asgi_app(disable_compression=True) +``` + #### Flask To use Prometheus with [Flask](http://flask.pocoo.org/) we need to serve metrics through a Prometheus WSGI application. This can be achieved using [Flask's application dispatching](http://flask.pocoo.org/docs/latest/patterns/appdispatch/). Below is a working example. 
diff --git a/prometheus_client/asgi.py b/prometheus_client/asgi.py index c1518ce5..6f72838b 100644 --- a/prometheus_client/asgi.py +++ b/prometheus_client/asgi.py @@ -5,7 +5,7 @@ from .registry import CollectorRegistry, REGISTRY -def make_asgi_app(registry: CollectorRegistry = REGISTRY) -> Callable: +def make_asgi_app(registry: CollectorRegistry = REGISTRY, disable_compression: bool = False) -> Callable: """Create a ASGI app which serves the metrics from a registry.""" async def prometheus_app(scope, receive, send): @@ -14,10 +14,17 @@ async def prometheus_app(scope, receive, send): params = parse_qs(scope.get('query_string', b'')) accept_header = "Accept: " + ",".join([ value.decode("utf8") for (name, value) in scope.get('headers') - if name.decode("utf8") == 'accept' + if name.decode("utf8").lower() == 'accept' + ]) + accept_encoding_header = ",".join([ + value.decode("utf8") for (name, value) in scope.get('headers') + if name.decode("utf8").lower() == 'accept-encoding' ]) # Bake output - status, header, output = _bake_output(registry, accept_header, params) + status, headers, output = _bake_output(registry, accept_header, accept_encoding_header, params, disable_compression) + formatted_headers = [] + for header in headers: + formatted_headers.append(tuple(x.encode('utf8') for x in header)) # Return output payload = await receive() if payload.get("type") == "http.request": @@ -25,9 +32,7 @@ async def prometheus_app(scope, receive, send): { "type": "http.response.start", "status": int(status.split(' ')[0]), - "headers": [ - tuple(x.encode('utf8') for x in header) - ] + "headers": formatted_headers, } ) await send({"type": "http.response.body", "body": output}) diff --git a/prometheus_client/exposition.py b/prometheus_client/exposition.py index fc0a1f18..86a9be48 100644 --- a/prometheus_client/exposition.py +++ b/prometheus_client/exposition.py @@ -1,5 +1,6 @@ import base64 from contextlib import closing +import gzip from http.server import BaseHTTPRequestHandler 
import os import socket @@ -93,32 +94,39 @@ def redirect_request(self, req, fp, code, msg, headers, newurl): return new_request -def _bake_output(registry, accept_header, params): +def _bake_output(registry, accept_header, accept_encoding_header, params, disable_compression): """Bake output for metrics output.""" - encoder, content_type = choose_encoder(accept_header) + # Choose the correct plain text format of the output. + formatter, content_type = choose_formatter(accept_header) if 'name[]' in params: registry = registry.restricted_registry(params['name[]']) - output = encoder(registry) - return '200 OK', ('Content-Type', content_type), output + output = formatter(registry) + headers = [('Content-Type', content_type)] + # If gzip encoding required, gzip the output. + if not disable_compression and gzip_accepted(accept_encoding_header): + output = gzip.compress(output) + headers.append(('Content-Encoding', 'gzip')) + return '200 OK', headers, output -def make_wsgi_app(registry: CollectorRegistry = REGISTRY) -> Callable: +def make_wsgi_app(registry: CollectorRegistry = REGISTRY, disable_compression: bool = False) -> Callable: """Create a WSGI app which serves the metrics from a registry.""" def prometheus_app(environ, start_response): # Prepare parameters accept_header = environ.get('HTTP_ACCEPT') + accept_encoding_header = environ.get('HTTP_ACCEPT_ENCODING') params = parse_qs(environ.get('QUERY_STRING', '')) if environ['PATH_INFO'] == '/favicon.ico': # Serve empty response for browsers status = '200 OK' - header = ('', '') + headers = [('', '')] output = b'' else: # Bake output - status, header, output = _bake_output(registry, accept_header, params) + status, headers, output = _bake_output(registry, accept_header, accept_encoding_header, params, disable_compression) # Return output - start_response(status, [header]) + start_response(status, headers) return [output] return prometheus_app @@ -152,8 +160,10 @@ def _get_best_family(address, port): def 
start_wsgi_server(port: int, addr: str = '0.0.0.0', registry: CollectorRegistry = REGISTRY) -> None: """Starts a WSGI server for prometheus metrics as a daemon thread.""" + class TmpServer(ThreadingWSGIServer): """Copy of ThreadingWSGIServer to update address_family locally""" + TmpServer.address_family, addr = _get_best_family(addr, port) app = make_wsgi_app(registry) httpd = make_server(addr, port, app, TmpServer, handler_class=_SilentHandler) @@ -227,7 +237,7 @@ def sample_line(line): return ''.join(output).encode('utf-8') -def choose_encoder(accept_header: str) -> Tuple[Callable[[CollectorRegistry], bytes], str]: +def choose_formatter(accept_header: str) -> Tuple[Callable[[CollectorRegistry], bytes], str]: accept_header = accept_header or '' for accepted in accept_header.split(','): if accepted.split(';')[0].strip() == 'application/openmetrics-text': @@ -236,6 +246,14 @@ def choose_encoder(accept_header: str) -> Tuple[Callable[[CollectorRegistry], by return generate_latest, CONTENT_TYPE_LATEST +def gzip_accepted(accept_encoding_header: str) -> bool: + accept_encoding_header = accept_encoding_header or '' + for accepted in accept_encoding_header.split(','): + if accepted.split(';')[0].strip().lower() == 'gzip': + return True + return False + + class MetricsHandler(BaseHTTPRequestHandler): """HTTP handler that gives metrics from ``REGISTRY``.""" registry: CollectorRegistry = REGISTRY @@ -244,12 +262,14 @@ def do_GET(self) -> None: # Prepare parameters registry = self.registry accept_header = self.headers.get('Accept') + accept_encoding_header = self.headers.get('Accept-Encoding') params = parse_qs(urlparse(self.path).query) # Bake output - status, header, output = _bake_output(registry, accept_header, params) + status, headers, output = _bake_output(registry, accept_header, accept_encoding_header, params, False) # Return output self.send_response(int(status.split(' ')[0])) - self.send_header(*header) + for header in headers: + self.send_header(*header) 
self.end_headers() self.wfile.write(output) @@ -289,14 +309,13 @@ def write_to_textfile(path: str, registry: CollectorRegistry) -> None: def _make_handler( - url: str, - method: str, - timeout: Optional[float], - headers: Sequence[Tuple[str, str]], - data: bytes, - base_handler: type, + url: str, + method: str, + timeout: Optional[float], + headers: Sequence[Tuple[str, str]], + data: bytes, + base_handler: type, ) -> Callable[[], None]: - def handle() -> None: request = Request(url, data=data) request.get_method = lambda: method # type: ignore @@ -310,11 +329,11 @@ def handle() -> None: def default_handler( - url: str, - method: str, - timeout: Optional[float], - headers: List[Tuple[str, str]], - data: bytes, + url: str, + method: str, + timeout: Optional[float], + headers: List[Tuple[str, str]], + data: bytes, ) -> Callable[[], None]: """Default handler that implements HTTP/HTTPS connections. @@ -324,11 +343,11 @@ def default_handler( def passthrough_redirect_handler( - url: str, - method: str, - timeout: Optional[float], - headers: List[Tuple[str, str]], - data: bytes, + url: str, + method: str, + timeout: Optional[float], + headers: List[Tuple[str, str]], + data: bytes, ) -> Callable[[], None]: """ Handler that automatically trusts redirect responses for all HTTP methods. @@ -344,13 +363,13 @@ def passthrough_redirect_handler( def basic_auth_handler( - url: str, - method: str, - timeout: Optional[float], - headers: List[Tuple[str, str]], - data: bytes, - username: str = None, - password: str = None, + url: str, + method: str, + timeout: Optional[float], + headers: List[Tuple[str, str]], + data: bytes, + username: str = None, + password: str = None, ) -> Callable[[], None]: """Handler that implements HTTP/HTTPS connections with Basic Auth. 
@@ -371,12 +390,12 @@ def handle(): def push_to_gateway( - gateway: str, - job: str, - registry: CollectorRegistry, - grouping_key: Optional[Dict[str, Any]] = None, - timeout: Optional[float] = 30, - handler: Callable = default_handler, + gateway: str, + job: str, + registry: CollectorRegistry, + grouping_key: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = 30, + handler: Callable = default_handler, ) -> None: """Push metrics to the given pushgateway. @@ -420,12 +439,12 @@ def push_to_gateway( def pushadd_to_gateway( - gateway: str, - job: str, - registry: Optional[CollectorRegistry], - grouping_key: Optional[Dict[str, Any]] = None, - timeout: Optional[float] = 30, - handler: Callable = default_handler, + gateway: str, + job: str, + registry: Optional[CollectorRegistry], + grouping_key: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = 30, + handler: Callable = default_handler, ) -> None: """PushAdd metrics to the given pushgateway. @@ -451,11 +470,11 @@ def pushadd_to_gateway( def delete_from_gateway( - gateway: str, - job: str, - grouping_key: Optional[Dict[str, Any]] = None, - timeout: Optional[float] = 30, - handler: Callable = default_handler, + gateway: str, + job: str, + grouping_key: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = 30, + handler: Callable = default_handler, ) -> None: """Delete metrics from the given pushgateway. @@ -480,13 +499,13 @@ def delete_from_gateway( def _use_gateway( - method: str, - gateway: str, - job: str, - registry: Optional[CollectorRegistry], - grouping_key: Optional[Dict[str, Any]], - timeout: Optional[float], - handler: Callable, + method: str, + gateway: str, + job: str, + registry: Optional[CollectorRegistry], + grouping_key: Optional[Dict[str, Any]], + timeout: Optional[float], + handler: Callable, ) -> None: gateway_url = urlparse(gateway) # See https://bugs.python.org/issue27657 for details on urlparse in py>=3.7.6. 
diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 77278bb6..50d76d6d 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,3 +1,4 @@ +import gzip from unittest import skipUnless, TestCase from prometheus_client import CollectorRegistry, Counter @@ -74,20 +75,12 @@ def get_all_output(self): break return outputs - def validate_metrics(self, metric_name, help_text, increments): - """ - ASGI app serves the metrics from the provided registry. - """ + def increment_metrics(self, metric_name, help_text, increments): c = Counter(metric_name, help_text, registry=self.registry) for _ in range(increments): c.inc() - # Create and run ASGI app - app = make_asgi_app(self.registry) - self.seed_app(app) - self.send_default_request() - # Assert outputs - outputs = self.get_all_output() - # Assert outputs + + def assert_outputs(self, outputs, metric_name, help_text, increments, compressed): self.assertEqual(len(outputs), 2) response_start = outputs[0] self.assertEqual(response_start['type'], 'http.response.start') @@ -96,14 +89,33 @@ def validate_metrics(self, metric_name, help_text, increments): # Status code self.assertEqual(response_start['status'], 200) # Headers - self.assertEqual(len(response_start['headers']), 1) - self.assertEqual(response_start['headers'][0], (b"Content-Type", CONTENT_TYPE_LATEST.encode('utf8'))) + num_of_headers = 2 if compressed else 1 + self.assertEqual(len(response_start['headers']), num_of_headers) + self.assertIn((b"Content-Type", CONTENT_TYPE_LATEST.encode('utf8')), response_start['headers']) + if compressed: + self.assertIn((b"Content-Encoding", b"gzip"), response_start['headers']) # Body - output = response_body['body'].decode('utf8') + if compressed: + output = gzip.decompress(response_body['body']).decode('utf8') + else: + output = response_body['body'].decode('utf8') self.assertIn("# HELP " + metric_name + "_total " + help_text + "\n", output) self.assertIn("# TYPE " + metric_name + "_total counter\n", output) 
self.assertIn(metric_name + "_total " + str(increments) + ".0\n", output) + def validate_metrics(self, metric_name, help_text, increments): + """ + ASGI app serves the metrics from the provided registry. + """ + self.increment_metrics(metric_name, help_text, increments) + # Create and run ASGI app + app = make_asgi_app(self.registry) + self.seed_app(app) + self.send_default_request() + # Assert outputs + outputs = self.get_all_output() + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=False) + def test_report_metrics_1(self): self.validate_metrics("counter", "A counter", 2) @@ -115,3 +127,34 @@ def test_report_metrics_3(self): def test_report_metrics_4(self): self.validate_metrics("failed_requests", "Number of failed requests", 7) + + def test_gzip(self): + # Increment a metric. + metric_name = "counter" + help_text = "A counter" + increments = 2 + self.increment_metrics(metric_name, help_text, increments) + app = make_asgi_app(self.registry) + self.seed_app(app) + # Send input with gzip header. + self.scope["headers"] = [(b"accept-encoding", b"gzip")] + self.send_input({"type": "http.request", "body": b""}) + # Assert outputs are compressed. + outputs = self.get_all_output() + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=True) + + def test_gzip_disabled(self): + # Increment a metric. + metric_name = "counter" + help_text = "A counter" + increments = 2 + self.increment_metrics(metric_name, help_text, increments) + # Disable compression explicitly. + app = make_asgi_app(self.registry, disable_compression=True) + self.seed_app(app) + # Send input with gzip header. + self.scope["headers"] = [(b"accept-encoding", b"gzip")] + self.send_input({"type": "http.request", "body": b""}) + # Assert outputs are not compressed. 
+ outputs = self.get_all_output() + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=False) diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 050c5add..2ecfd728 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,3 +1,4 @@ +import gzip from unittest import TestCase from wsgiref.util import setup_testing_defaults @@ -18,29 +19,41 @@ def capture(self, status, header): self.captured_status = status self.captured_headers = header - def validate_metrics(self, metric_name, help_text, increments): - """ - WSGI app serves the metrics from the provided registry. - """ + def increment_metrics(self, metric_name, help_text, increments): c = Counter(metric_name, help_text, registry=self.registry) for _ in range(increments): c.inc() - # Create and run WSGI app - app = make_wsgi_app(self.registry) - outputs = app(self.environ, self.capture) - # Assert outputs + + def assert_outputs(self, outputs, metric_name, help_text, increments, compressed): self.assertEqual(len(outputs), 1) - output = outputs[0].decode('utf8') + if compressed: + output = gzip.decompress(outputs[0]).decode(encoding="utf-8") + else: + output = outputs[0].decode('utf8') # Status code self.assertEqual(self.captured_status, "200 OK") # Headers - self.assertEqual(len(self.captured_headers), 1) - self.assertEqual(self.captured_headers[0], ("Content-Type", CONTENT_TYPE_LATEST)) + num_of_headers = 2 if compressed else 1 + self.assertEqual(len(self.captured_headers), num_of_headers) + self.assertIn(("Content-Type", CONTENT_TYPE_LATEST), self.captured_headers) + if compressed: + self.assertIn(("Content-Encoding", "gzip"), self.captured_headers) # Body self.assertIn("# HELP " + metric_name + "_total " + help_text + "\n", output) self.assertIn("# TYPE " + metric_name + "_total counter\n", output) self.assertIn(metric_name + "_total " + str(increments) + ".0\n", output) + def validate_metrics(self, metric_name, help_text, increments): + """ + WSGI app serves the metrics from 
the provided registry. + """ + self.increment_metrics(metric_name, help_text, increments) + # Create and run WSGI app + app = make_wsgi_app(self.registry) + outputs = app(self.environ, self.capture) + # Assert outputs + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=False) + def test_report_metrics_1(self): self.validate_metrics("counter", "A counter", 2) @@ -70,3 +83,32 @@ def test_favicon_path(self): # Try accessing normal paths app(self.environ, self.capture) self.assertEqual(mock.call_count, 1) + + def test_gzip(self): + # Increment a metric + metric_name = "counter" + help_text = "A counter" + increments = 2 + self.increment_metrics(metric_name, help_text, increments) + app = make_wsgi_app(self.registry) + # Try accessing metrics using the gzip Accept-Content header. + gzip_environ = dict(self.environ) + gzip_environ['HTTP_ACCEPT_ENCODING'] = 'gzip' + outputs = app(gzip_environ, self.capture) + # Assert outputs are compressed. + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=True) + + def test_gzip_disabled(self): + # Increment a metric + metric_name = "counter" + help_text = "A counter" + increments = 2 + self.increment_metrics(metric_name, help_text, increments) + # Disable compression explicitly. + app = make_wsgi_app(self.registry, disable_compression=True) + # Try accessing metrics using the gzip Accept-Content header. + gzip_environ = dict(self.environ) + gzip_environ['HTTP_ACCEPT_ENCODING'] = 'gzip' + outputs = app(gzip_environ, self.capture) + # Assert outputs are not compressed. + self.assert_outputs(outputs, metric_name, help_text, increments, compressed=False)