diff --git a/datadog_checks_base/datadog_checks/base/utils/db/utils.py b/datadog_checks_base/datadog_checks/base/utils/db/utils.py
index be3e64fa96985c..6c897031d1750f 100644
--- a/datadog_checks_base/datadog_checks/base/utils/db/utils.py
+++ b/datadog_checks_base/datadog_checks/base/utils/db/utils.py
@@ -4,12 +4,18 @@
 import datetime
 import decimal
 import logging
+import os
 import socket
+import threading
 import time
+from concurrent.futures.thread import ThreadPoolExecutor
 from itertools import chain
 
 from cachetools import TTLCache
 
+from datadog_checks.base import is_affirmative
+from datadog_checks.base.log import get_check_logger
+
 try:
     import datadog_agent
 except ImportError:
@@ -75,8 +81,8 @@ def __init__(self, rate_limit_s):
         """
         :param rate_limit_s: rate limit in seconds
         """
-        self.rate_limit_s = rate_limit_s
-        self.period_s = 1.0 / rate_limit_s if rate_limit_s > 0 else 0
+        # clamp negative rates to 0, which disables sleeping entirely
+        self.rate_limit_s = max(rate_limit_s, 0)
+        self.period_s = 1.0 / self.rate_limit_s if self.rate_limit_s > 0 else 0
         self.last_event = 0
 
     def sleep(self):
@@ -144,3 +150,118 @@ def default_json_event_encoding(o):
     if isinstance(o, (datetime.date, datetime.datetime)):
         return o.isoformat()
     raise TypeError
+
+
+class DBMAsyncJob(object):
+    """
+    Runs async jobs.
+    """
+
+    # shared by all instances so the total number of job threads stays bounded
+    executor = ThreadPoolExecutor()
+
+    def __init__(
+        self,
+        check,
+        config_host=None,
+        min_collection_interval=15,
+        dbms="TODO",
+        rate_limit=1,
+        run_sync=False,
+        enabled=True,
+        expected_db_exceptions=(),
+        shutdown_callback=None,
+        job_name=None,
+    ):
+        self._check = check
+        self._config_host = config_host
+        self._min_collection_interval = min_collection_interval
+        self._log = get_check_logger()
+        self._job_loop_future = None
+        self._cancel_event = threading.Event()
+        self._tags = None
+        self._tags_no_db = None
+        self._db_hostname = None
+        self._last_check_run = 0
+        self._shutdown_callback = shutdown_callback
+        self._dbms = dbms
+        self._rate_limiter = ConstantRateLimiter(rate_limit)
+        self._run_sync = run_sync
+        self._enabled = enabled
+        self._expected_db_exceptions = expected_db_exceptions
+        self._job_name = job_name
+
+    def cancel(self):
+        # signal the job loop to stop after its current iteration
+        self._cancel_event.set()
+
+    def run_job_loop(self, tags):
+        """
+        Start the job loop if it is not already running, or run the job synchronously if
+        run_sync (or the DBM_THREADED_JOB_RUN_SYNC environment variable) is set.
+        :param tags: check tags to attach to the job's metrics
+        """
+        if not self._enabled:
+            self._log.debug("[job=%s] Job not enabled.", self._job_name)
+            return
+        if not self._db_hostname:
+            self._db_hostname = resolve_db_host(self._config_host)
+        self._tags = tags
+        self._tags_str = ','.join(self._tags)
+        self._job_tags = self._tags + ["job:{}".format(self._job_name)]
+        self._job_tags_str = ','.join(self._job_tags)
+        self._last_check_run = time.time()
+        if self._run_sync or is_affirmative(os.environ.get('DBM_THREADED_JOB_RUN_SYNC', "false")):
+            self._log.debug("Running threaded job synchronously. job=%s", self._job_name)
+            self._run_job_rate_limited()
+        elif self._job_loop_future is None or not self._job_loop_future.running():
+            self._job_loop_future = DBMAsyncJob.executor.submit(self._job_loop)
+        else:
+            self._log.debug("Job loop already running. job=%s", self._job_name)
+
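+    # The loop below is the unit of work submitted to the shared executor. It exits when
+    # cancel() is called, when the check stops calling run_job_loop() for more than twice
+    # the collection interval, or when an exception escapes run_job().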
job=%s", self._job_name) + + def _job_loop(self): + try: + self._log.info("[%s] Starting job loop", self._job_tags_str) + while True: + if self._cancel_event.isSet(): + self._log.info("[%s] Job loop cancelled", self._job_tags_str) + self._check.count("dd.{}.async_job.cancel".format(self._dbms), 1, tags=self._job_tags) + break + if time.time() - self._last_check_run > self._min_collection_interval * 2: + self._log.info("[%s] Job loop stopping due to check inactivity", self._job_tags_str) + self._check.count("dd.{}.async_job.inactive_stop".format(self._dbms), 1, tags=self._job_tags) + break + self._run_job_rate_limited() + except self._expected_db_exceptions as e: + self._log.warning( + "[%s] Job loop database error: %s", + self._job_tags_str, + e, + exc_info=self._log.getEffectiveLevel() == logging.DEBUG, + ) + self._check.count( + "dd.{}.async_job.error".format(self._dbms), + 1, + tags=self._job_tags + ["error:database-{}".format(type(e))], + ) + except Exception as e: + self._log.exception("[%s] Job loop crash", self._job_tags_str) + self._check.count( + "dd.{}.async_job.error".format(self._dbms), + 1, + tags=self._job_tags + ["error:crash-{}".format(type(e))], + ) + finally: + self._log.info("[%s] Shutting down job loop", self._job_tags_str) + if self._shutdown_callback: + self._shutdown_callback() + + def _set_rate_limit(self, rate_limit): + if self._rate_limiter.rate_limit_s != rate_limit: + self._rate_limiter = ConstantRateLimiter(rate_limit) + + def _run_job_rate_limited(self): + self.run_job() + self._rate_limiter.sleep() + + def run_job(self): + raise NotImplementedError() diff --git a/datadog_checks_base/tests/test_db_util.py b/datadog_checks_base/tests/test_db_util.py index 7235335019b3c9..c39b2224e788c4 100644 --- a/datadog_checks_base/tests/test_db_util.py +++ b/datadog_checks_base/tests/test_db_util.py @@ -3,8 +3,12 @@ # All rights reserved # Licensed under a 3-clause BSD style license (see LICENSE) import time +from concurrent.futures.thread import ThreadPoolExecutor -from datadog_checks.base.utils.db.utils import ConstantRateLimiter, RateLimitingTTLCache +import pytest + +from datadog_checks.base import AgentCheck +from datadog_checks.base.utils.db.utils import ConstantRateLimiter, DBMAsyncJob, RateLimitingTTLCache def test_constant_rate_limiter(): @@ -37,3 +41,98 @@ def test_ratelimiting_ttl_cache(): for i in range(5, 10): assert cache.acquire(i), "cache should be empty again so these keys should go in OK" + + +class TestDBExcepption(BaseException): + pass + + +class TestJob(DBMAsyncJob): + def __init__(self, check, run_sync=False, enabled=True, rate_limit=10, min_collection_interval=15): + super(TestJob, self).__init__( + check, + run_sync=run_sync, + enabled=enabled, + expected_db_exceptions=(TestDBExcepption,), + min_collection_interval=min_collection_interval, + config_host="test-host", + dbms="test-dbms", + rate_limit=rate_limit, + job_name="test-job", + shutdown_callback=self.test_shutdown, + ) + + def test_shutdown(self): + self._check.count("dbm.async_job_test.shutdown", 1) + + def run_job(self): + self._check.count("dbm.async_job_test.run_job", 1) + + +def test_dbm_async_job(): + check = AgentCheck() + TestJob(check) + + +@pytest.fixture(autouse=True) +def stop_orphaned_threads(): + # make sure we shut down any orphaned threads and create a new Executor for each test + DBMAsyncJob.executor.shutdown(wait=True) + DBMAsyncJob.executor = ThreadPoolExecutor() + + +@pytest.mark.parametrize("enabled", [True, False]) +def test_dbm_async_job_enabled(enabled): + check = 
+
+
+@pytest.mark.parametrize("enabled", [True, False])
+def test_dbm_async_job_enabled(enabled):
+    check = AgentCheck()
+    job = TestJob(check, enabled=enabled)
+    job.run_job_loop([])
+    if enabled:
+        assert job._job_loop_future is not None
+        job.cancel()
+        job._job_loop_future.result()
+    else:
+        assert job._job_loop_future is None
+
+
+def test_dbm_async_job_cancel(aggregator):
+    job = TestJob(AgentCheck())
+    tags = ["hello:there"]
+    job.run_job_loop(tags)
+    job.cancel()
+    job._job_loop_future.result()
+    assert not job._job_loop_future.running(), "thread should be stopped"
+    # the cancel metric and shutdown callback fire even if the thread doesn't start until
+    # after the cancel signal is set, because the loop always runs its shutdown path
+    expected_tags = tags + ['job:test-job']
+    aggregator.assert_metric("dd.test-dbms.async_job.cancel", tags=expected_tags)
+    aggregator.assert_metric("dbm.async_job_test.shutdown")
+
+
+def test_dbm_async_job_run_sync(aggregator):
+    job = TestJob(AgentCheck(), run_sync=True)
+    job.run_job_loop([])
+    assert job._job_loop_future is None
+    aggregator.assert_metric("dbm.async_job_test.run_job")
+
+
+def test_dbm_async_job_rate_limit(aggregator):
+    # test the rate limit of the main collection loop
+    rate_limit = 10
+    sleep_time = 1
+
+    job = TestJob(AgentCheck(), rate_limit=rate_limit)
+    job.run_job_loop([])
+
+    time.sleep(sleep_time)
+    max_collections = int(rate_limit * sleep_time) + 1
+    job.cancel()
+
+    metrics = aggregator.metrics("dbm.async_job_test.run_job")
+    assert max_collections / 2.0 <= len(metrics) <= max_collections
+
+
+def test_dbm_async_job_inactive_stop(aggregator):
+    job = TestJob(AgentCheck(), rate_limit=10, min_collection_interval=1)
+    job.run_job_loop([])
+    job._job_loop_future.result()
+    aggregator.assert_metric("dd.test-dbms.async_job.inactive_stop", tags=['job:test-job'])
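
Usage note: a concrete job only has to subclass DBMAsyncJob and override run_job(); the
threading, rate limiting, cancellation, and inactivity shutdown all come from the base class,
as the tests above exercise. The sketch below shows how a check might wire a job in. It is
illustrative only: ExampleCheck, SampleCollectionJob, and the metric names are hypothetical
and not part of this change.

from datadog_checks.base import AgentCheck
from datadog_checks.base.utils.db.utils import DBMAsyncJob


class SampleCollectionJob(DBMAsyncJob):
    # hypothetical job: emits one metric per loop iteration, twice per second
    def __init__(self, check):
        super(SampleCollectionJob, self).__init__(
            check,
            config_host="localhost",     # resolved to the reported db hostname
            dbms="example-dbms",         # namespaces the dd.<dbms>.async_job.* metrics
            rate_limit=2,                # run_job() at most twice per second
            min_collection_interval=15,  # loop stops after 2x this with no check runs
            expected_db_exceptions=(),   # db errors logged as warnings, not crashes
            job_name="sample-collection",
        )

    def run_job(self):
        # a real job would query the database here
        self._check.count("dbm.sample_collection.heartbeat", 1, tags=self._job_tags)


class ExampleCheck(AgentCheck):
    def __init__(self, name, init_config, instances):
        super(ExampleCheck, self).__init__(name, init_config, instances)
        self._job = SampleCollectionJob(self)

    def check(self, _):
        # each run re-arms the inactivity timer and starts the loop if it isn't running
        self._job.run_job_loop(self.instance.get('tags', []))

    def cancel(self):
        # invoked when the check is unscheduled (on agent versions that support it)
        self._job.cancel()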