From 8c66f918ec2950d19c51f82575853813c4f72d94 Mon Sep 17 00:00:00 2001
From: Arpad Borsos
Date: Wed, 16 Oct 2024 11:32:03 +0200
Subject: [PATCH] Use correct signatures for Celery Task Hooks

This explicitly declares and forwards all the arguments to the
`on_success`/`on_retry`/`on_failure` task hooks, as they are documented at
https://docs.celeryq.dev/en/main/_modules/celery/app/task.html#Task.on_success

The reason is that the Sentry tags set via the `MetricContext` constructor
were not making their way to Sentry, possibly because the previous hooks
misused `kwargs` and overwrote those tag values with `None`. This change
should resolve that mystery.

As a drive-by change, I also removed all the deprecated statsd `metrics`
calls; all of the relevant metrics already have Prometheus equivalents.
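For reference, a minimal sketch of the hook signatures this change adopts,
matching the Celery documentation linked above (the `SomeTask` class below is
purely illustrative and not code from this patch):

    from celery import Task

    class SomeTask(Task):  # illustrative subclass, not part of this patch
        # Celery passes the *task's* positional and keyword arguments to each
        # hook, so values like kwargs.get("repoid") stay available for tagging.
        def on_success(self, retval, task_id, args, kwargs):
            ...

        def on_retry(self, exc, task_id, args, kwargs, einfo):
            ...

        def on_failure(self, exc, task_id, args, kwargs, einfo):
            ...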
---
 helpers/telemetry.py          | 21 +++------
 tasks/base.py                 | 84 +++++++++++++----------------------
 tasks/tests/unit/test_base.py | 44 +++----------
 3 files changed, 44 insertions(+), 105 deletions(-)

diff --git a/helpers/telemetry.py b/helpers/telemetry.py
index 98af0061a..00d1897eb 100644
--- a/helpers/telemetry.py
+++ b/helpers/telemetry.py
@@ -68,39 +68,30 @@ def __init__(
         sentry_sdk.set_tag("owner_id", owner_id)
         sentry_sdk.set_tag("repo_id", repo_id)
         sentry_sdk.set_tag("commit_sha", commit_sha)
-        transaction = sentry_sdk.get_current_scope().transaction
-        if transaction is not None:
-            transaction.set_tag("owner_id", owner_id)
-            transaction.set_tag("repo_id", repo_id)
-            transaction.set_tag("commit_sha", commit_sha)
 
     def populate(self):
         if self.populated:
             return
 
-        repo = None
-        commit = None
         dbsession = get_db_session()
 
         if self.repo_id:
             if not self.owner_id:
-                repo = (
-                    dbsession.query(Repository)
+                self.owner_id = (
+                    dbsession.query(Repository.ownerid)
                     .filter(Repository.repoid == self.repo_id)
-                    .first()
+                    .first()[0]
                 )
-                self.owner_id = repo.ownerid
 
             if self.commit_sha and not self.commit_id:
-                commit = (
-                    dbsession.query(Commit)
+                self.commit_id = (
+                    dbsession.query(Commit.id_)
                     .filter(
                         Commit.repoid == self.repo_id,
                         Commit.commitid == self.commit_sha,
                     )
-                    .first()
+                    .first()[0]
                 )
-                self.commit_id = commit.id_
 
         self.populated = True
diff --git a/tasks/base.py b/tasks/base.py
index ea549e91c..00f1f2982 100644
--- a/tasks/base.py
+++ b/tasks/base.py
@@ -24,7 +24,6 @@
     log_set_task_id,
     log_set_task_name,
 )
-from helpers.metrics import metrics
 from helpers.telemetry import MetricContext, TimeseriesTimer
 from helpers.timeseries import timeseries_enabled
 
@@ -51,9 +50,7 @@ def on_timeout(self, soft: bool, timeout: int):
         res = super().on_timeout(soft, timeout)
         if not soft:
             REQUEST_HARD_TIMEOUT_COUNTER.labels(task=self.name).inc()
-            metrics.incr(f"{self.metrics_prefix}.hardtimeout")
         REQUEST_TIMEOUT_COUNTER.labels(task=self.name).inc()
-        metrics.incr(f"{self.metrics_prefix}.timeout")
         return res
 
@@ -245,7 +242,6 @@ def _emit_queue_metrics(self):
         enqueued_time = datetime.fromisoformat(created_timestamp)
         now = datetime.now()
         delta = now - enqueued_time
-        metrics.timing(f"{self.metrics_prefix}.time_in_queue", delta)
 
         queue_name = self.request.get("delivery_info", {}).get("routing_key", None)
         time_in_queue_timer = TASK_TIME_IN_QUEUE.labels(
@@ -253,12 +249,6 @@ def _emit_queue_metrics(self):
         )  # TODO is None a valid label value
         time_in_queue_timer.observe(delta.total_seconds())
 
-        if queue_name:
-            metrics.timing(f"worker.queues.{queue_name}.time_in_queue", delta)
-            metrics.timing(
-                f"{self.metrics_prefix}.{queue_name}.time_in_queue", delta
-            )
-
     def run(self, *args, **kwargs):
         task = get_current_task()
@@ -279,39 +269,32 @@ def run(self, *args, **kwargs):
             owner_id=kwargs.get("ownerid"),
         )
 
-        with TimeseriesTimer(
-            metric_context, f"{self.metrics_prefix}.full_runtime", sync=True
-        ):
-            with self.task_full_runtime.time():  # Timer isn't tested
-                with metrics.timer(f"{self.metrics_prefix}.full"):
-                    db_session = get_db_session()
-                    try:
-                        with TimeseriesTimer(
-                            metric_context,
-                            f"{self.metrics_prefix}.core_runtime",
-                            sync=True,
-                        ):
-                            with self.task_core_runtime.time():  # Timer isn't tested
-                                with metrics.timer(f"{self.metrics_prefix}.run"):
-                                    return self.run_impl(db_session, *args, **kwargs)
-                    except (DataError, IntegrityError):
-                        log.exception(
-                            "Errors related to the constraints of database happened",
-                            extra=dict(task_args=args, task_kwargs=kwargs),
-                        )
-                        db_session.rollback()
-                        self._rollback_django()
-                        self.retry()
-                    except SQLAlchemyError as ex:
-                        self._analyse_error(ex, args, kwargs)
-                        db_session.rollback()
-                        self._rollback_django()
-                        self.retry()
-                    finally:
-                        log_set_task_name(None)
-                        log_set_task_id(None)
-                        self.wrap_up_dbsession(db_session)
-                        self._commit_django()
+        with self.task_full_runtime.time():  # Timer isn't tested
+            db_session = get_db_session()
+            try:
+                with TimeseriesTimer(
+                    metric_context, f"{self.metrics_prefix}.core_runtime", sync=True
+                ):
+                    with self.task_core_runtime.time():  # Timer isn't tested
+                        return self.run_impl(db_session, *args, **kwargs)
+            except (DataError, IntegrityError):
+                log.exception(
+                    "Errors related to the constraints of database happened",
+                    extra=dict(task_args=args, task_kwargs=kwargs),
+                )
+                db_session.rollback()
+                self._rollback_django()
+                self.retry()
+            except SQLAlchemyError as ex:
+                self._analyse_error(ex, args, kwargs)
+                db_session.rollback()
+                self._rollback_django()
+                self.retry()
+            finally:
+                log_set_task_name(None)
+                log_set_task_id(None)
+                self.wrap_up_dbsession(db_session)
+                self._commit_django()
 
     def wrap_up_dbsession(self, db_session):
         """
@@ -352,10 +335,9 @@ def wrap_up_dbsession(self, db_session):
             )
             get_db_session.remove()
 
-    def on_retry(self, *args, **kwargs):
-        res = super().on_retry(*args, **kwargs)
+    def on_retry(self, exc, task_id, args, kwargs, einfo):
+        res = super().on_retry(exc, task_id, args, kwargs, einfo)
         self.task_retry_counter.inc()
-        metrics.incr(f"{self.metrics_prefix}.retries")
         metric_context = MetricContext(
             commit_sha=kwargs.get("commitid"),
             repo_id=kwargs.get("repoid"),
@@ -364,10 +346,9 @@ def on_retry(self, *args, **kwargs):
         metric_context.log_simple_metric(f"{self.metrics_prefix}.retry", 1.0)
         return res
 
-    def on_success(self, *args, **kwargs):
-        res = super().on_success(*args, **kwargs)
+    def on_success(self, retval, task_id, args, kwargs):
+        res = super().on_success(retval, task_id, args, kwargs)
         self.task_success_counter.inc()
-        metrics.incr(f"{self.metrics_prefix}.successes")
         metric_context = MetricContext(
             commit_sha=kwargs.get("commitid"),
             repo_id=kwargs.get("repoid"),
@@ -376,13 +357,12 @@ def on_success(self, *args, **kwargs):
         metric_context.log_simple_metric(f"{self.metrics_prefix}.success", 1.0)
         return res
 
-    def on_failure(self, *args, **kwargs):
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
         """
         Includes SoftTimeoutLimitException, for example
         """
-        res = super().on_failure(*args, **kwargs)
+        res = super().on_failure(exc, task_id, args, kwargs, einfo)
         self.task_failure_counter.inc()
-        metrics.incr(f"{self.metrics_prefix}.failures")
         metric_context = MetricContext(
             commit_sha=kwargs.get("commitid"),
             repo_id=kwargs.get("repoid"),
diff --git a/tasks/tests/unit/test_base.py b/tasks/tests/unit/test_base.py
index 6c0018c33..246a65003 100644
--- a/tasks/tests/unit/test_base.py
+++ b/tasks/tests/unit/test_base.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta
+from datetime import datetime
 from pathlib import Path
 from unittest.mock import patch
 
@@ -104,7 +104,6 @@ def test_hard_time_limit_task_from_default_app(self, mocker):
     @patch("helpers.telemetry.MetricContext.log_simple_metric")
     def test_sample_run(self, mock_simple_metric, mocker, dbsession):
         mocked_get_db_session = mocker.patch("tasks.base.get_db_session")
-        mocked_metrics = mocker.patch("tasks.base.metrics")
         mock_task_request = mocker.patch("tasks.base.BaseCodecovTask.request")
         fake_request_values = dict(
             created_timestamp="2023-06-13 10:00:00.000000",
@@ -117,23 +116,6 @@ def test_sample_run(self, mock_simple_metric, mocker, dbsession):
         task_instance = SampleTask()
         result = task_instance.run()
         assert result == {"unusual": "return", "value": ["There"]}
-        assert mocked_metrics.timing.call_count == 3
-        mocked_metrics.timing.assert_has_calls(
-            [
-                call(
-                    "worker.task.test.SampleTask.time_in_queue",
-                    timedelta(seconds=61, microseconds=123),
-                ),
-                call(
-                    "worker.queues.my-queue.time_in_queue",
-                    timedelta(seconds=61, microseconds=123),
-                ),
-                call(
-                    "worker.task.test.SampleTask.my-queue.time_in_queue",
-                    timedelta(seconds=61, microseconds=123),
-                ),
-            ]
-        )
         assert (
             REGISTRY.get_sample_value(
                 "worker_tasks_timers_time_in_queue_seconds_sum",
@@ -142,10 +124,7 @@ def test_sample_run(self, mock_simple_metric, mocker, dbsession):
             == 61.000123
         )
         mock_simple_metric.assert_has_calls(
-            [
-                call("worker.task.test.SampleTask.core_runtime", ANY),
-                call("worker.task.test.SampleTask.full_runtime", ANY),
-            ]
+            [call("worker.task.test.SampleTask.core_runtime", ANY)]
         )
 
     @patch("tasks.base.BaseCodecovTask._emit_queue_metrics")
@@ -329,12 +308,11 @@ def test_run_sqlalchemy_error_rollback(self, mocker, dbsession, celery_app):
 
 @pytest.mark.django_db(databases={"default", "timeseries"})
 class TestBaseCodecovTaskHooks(object):
-    def test_sample_task_success(self, celery_app, mocker):
+    def test_sample_task_success(self, celery_app):
         class SampleTask(BaseCodecovTask, name="test.SampleTask"):
             def run_impl(self, dbsession):
                 return {"unusual": "return", "value": ["There"]}
 
-        mock_metrics = mocker.patch("tasks.base.metrics.incr")
         DTask = celery_app.register_task(SampleTask())
         task = celery_app.tasks[DTask.name]
@@ -354,16 +332,14 @@ def run_impl(self, dbsession):
 
         res = k.get()
         assert res == {"unusual": "return", "value": ["There"]}
-        mock_metrics.assert_called_with("worker.task.test.SampleTask.successes")
         assert prom_run_counter_after - prom_run_counter_before == 1
         assert prom_success_counter_after - prom_success_counter_before == 1
 
-    def test_sample_task_failure(self, celery_app, mocker):
+    def test_sample_task_failure(self, celery_app):
         class FailureSampleTask(BaseCodecovTask, name="test.FailureSampleTask"):
             def run_impl(self, *args, **kwargs):
                 raise Exception("Whhhhyyyyyyy")
 
-        mock_metrics = mocker.patch("tasks.base.metrics.incr")
         DTask = celery_app.register_task(FailureSampleTask())
         task = celery_app.tasks[DTask.name]
         with pytest.raises(Exception) as exc:
@@ -383,24 +359,21 @@ def run_impl(self, *args, **kwargs):
         assert prom_run_counter_after - prom_run_counter_before == 1
         assert prom_failure_counter_after - prom_failure_counter_before == 1
         assert exc.value.args == ("Whhhhyyyyyyy",)
-        mock_metrics.assert_called_with("worker.task.test.FailureSampleTask.failures")
 
-    def test_sample_task_retry(self, celery_app, mocker):
+    def test_sample_task_retry(self):
         # Unfortunately we cant really call the task with apply().get()
         # Something happens inside celery as of version 4.3 that makes them
         # not call on_Retry at all.
         # best we can do is to call on_retry ourselves and ensure this makes the
         # metric be called
-        mock_metrics = mocker.patch("tasks.base.metrics.incr")
         task = RetrySampleTask()
         prom_retry_counter_before = REGISTRY.get_sample_value(
             "worker_task_counts_retries_total", labels={"task": task.name}
         )
-        task.on_retry("exc", "task_id", "args", "kwargs", "einfo")
+        task.on_retry("exc", "task_id", ("args",), {"kwargs": "foo"}, "einfo")
         prom_retry_counter_after = REGISTRY.get_sample_value(
             "worker_task_counts_retries_total", labels={"task": task.name}
         )
-        mock_metrics.assert_called_with("worker.task.test.RetrySampleTask.retries")
 
         assert prom_retry_counter_after - prom_retry_counter_before == 1
 
@@ -435,7 +408,6 @@ def test_sample_task_timeout(self, celery_app, mocker):
         class SampleTask(BaseCodecovTask, name="test.SampleTask"):
             pass
 
-        mock_metrics = mocker.patch("tasks.base.metrics.incr")
         DTask = celery_app.register_task(SampleTask())
         request = self.xRequest(mocker, DTask.name, celery_app)
         prom_timeout_counter_before = (
@@ -448,14 +420,12 @@ class SampleTask(BaseCodecovTask, name="test.SampleTask"):
         prom_timeout_counter_after = REGISTRY.get_sample_value(
             "worker_task_counts_timeouts_total", labels={"task": DTask.name}
         )
-        mock_metrics.assert_called_with("worker.task.test.SampleTask.timeout")
         assert prom_timeout_counter_after - prom_timeout_counter_before == 1
 
     def test_sample_task_hard_timeout(self, celery_app, mocker):
         class SampleTask(BaseCodecovTask, name="test.SampleTask"):
             pass
 
-        mock_metrics = mocker.patch("tasks.base.metrics.incr")
         DTask = celery_app.register_task(SampleTask())
         request = self.xRequest(mocker, DTask.name, celery_app)
         prom_timeout_counter_before = (
@@ -477,8 +447,6 @@ class SampleTask(BaseCodecovTask, name="test.SampleTask"):
         prom_hard_timeout_counter_after = REGISTRY.get_sample_value(
             "worker_task_counts_hard_timeouts_total", labels={"task": DTask.name}
         )
-        mock_metrics.assert_any_call("worker.task.test.SampleTask.hardtimeout")
-        mock_metrics.assert_any_call("worker.task.test.SampleTask.timeout")
         assert prom_timeout_counter_after - prom_timeout_counter_before == 1
         assert prom_hard_timeout_counter_after - prom_hard_timeout_counter_before == 1