Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional argument for overriding get_tls_context() parameters #8275

Merged
merged 11 commits into from
Jan 7, 2021
25 changes: 21 additions & 4 deletions datadog_checks_base/datadog_checks/base/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,20 @@
import unicodedata
from collections import defaultdict, deque
from os.path import basename
from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Deque, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
DefaultDict,
Deque,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

import yaml
from six import binary_type, iteritems, text_type
Expand Down Expand Up @@ -309,15 +322,19 @@ def http(self):

return self._http

def get_tls_context(self, refresh=False):
# type: (bool) -> ssl.SSLContext
def get_tls_context(self, refresh=False, overrides=None):
# type: (bool, Dict[AnyStr, Any]) -> ssl.SSLContext
"""
Creates and cache an SSLContext instance based on user configuration.
Note that user configuration can be overridden by using `overrides`.
This should only be applied to older integration that manually set config values.

Since: Agent 7.24
"""
if not hasattr(self, '_tls_context_wrapper'):
self._tls_context_wrapper = TlsContextWrapper(self.instance or {}, self.TLS_CONFIG_REMAPPER)
self._tls_context_wrapper = TlsContextWrapper(
self.instance or {}, self.TLS_CONFIG_REMAPPER, overrides=overrides
)

if refresh:
self._tls_context_wrapper.refresh_tls_context()
Expand Down
13 changes: 11 additions & 2 deletions datadog_checks_base/datadog_checks/base/utils/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import ssl
from copy import deepcopy
from typing import TYPE_CHECKING, Any, AnyStr, Dict

from six import iteritems
Expand Down Expand Up @@ -35,10 +36,18 @@
class TlsContextWrapper(object):
__slots__ = ('logger', 'config', 'tls_context')

def __init__(self, instance, remapper=None):
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]]) -> None
def __init__(self, instance, remapper=None, overrides=None):
yzhan289 marked this conversation as resolved.
Show resolved Hide resolved
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]], Dict[AnyStr, Any]) -> None
default_fields = dict(STANDARD_FIELDS)

# Override existing config options if there exists any overrides
instance = deepcopy(instance)

if overrides:
for overridden_field, data in iteritems(overrides):
if instance.get(overridden_field):
instance[overridden_field] = data
yzhan289 marked this conversation as resolved.
Show resolved Hide resolved

# Populate with the default values
config = {field: instance.get(field, value) for field, value in iteritems(default_fields)}
for field in STANDARD_FIELDS:
Expand Down
40 changes: 40 additions & 0 deletions datadog_checks_base/tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,43 @@ def test_client_key_expanded(self):
with patch('ssl.SSLContext'), patch('os.path.expanduser') as mock_expand:
check.get_tls_context()
mock_expand.assert_called_with('~/foo')


class TestTLSContextOverrides:
yzhan289 marked this conversation as resolved.
Show resolved Hide resolved
def test_override_context(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {'tls_cert': 'not_foo'}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('not_foo', keyfile='bar', password=None)

def test_override_context_empty(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('foo', keyfile='bar', password=None)

def test_override_context_wrapper_config(self):
instance = {'tls_verify': True}
overrides = {'tls_verify': False}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is False
assert instance['tls_verify'] is True # Overrides should not affect the original instance

def test_override_context_wrapper_config_empty(self):
instance = {'tls_verify': True}
overrides = {}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is True

def test_override_non_exist_instance_config(self):
instance = {'tls_verify': True}
overrides = {'fake_config': 'foo'}
tls = TlsContextWrapper(instance, overrides=overrides)
assert instance.get('fake_config') is None
assert tls.config['tls_verify'] is True