diff --git a/snmp/datadog_checks/snmp/config.py b/snmp/datadog_checks/snmp/config.py index 7c38d9ea2c820..80885a02ff20a 100644 --- a/snmp/datadog_checks/snmp/config.py +++ b/snmp/datadog_checks/snmp/config.py @@ -5,13 +5,24 @@ from collections import defaultdict from typing import Any, Callable, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple, Union -from pyasn1.type.univ import OctetString -from pysnmp import hlapi -from pysnmp.hlapi.asyncore.cmdgen import lcd -from pysnmp.smi import builder, view - from datadog_checks.base import ConfigurationError, is_affirmative +from .models import ( + CommunityData, + ContextData, + DirMibSource, + MibViewController, + ObjectIdentity, + ObjectType, + OctetString, + SnmpEngine, + UdpTransportTarget, + UsmUserData, + hlapi, + lcd, + usmDESPrivProtocol, + usmHMACMD5AuthProtocol, +) from .resolver import OIDResolver from .utils import to_oid_tuple @@ -50,6 +61,7 @@ class ParsedMetricTag(object): __slots__ = ('name', 'symbol') def __init__(self, name, symbol): + # type: (str, str) -> None self.name = name self.symbol = symbol @@ -158,7 +170,7 @@ def __init__( self.refresh_with_profile(profiles[profile], warning, log) self.add_profile_tag(profile) - self._context_data = hlapi.ContextData(*self.get_context_data(instance)) + self._context_data = ContextData(*self.get_context_data(instance)) self._uptime_metric_added = False @@ -192,22 +204,24 @@ def refresh_with_profile(self, profile, warning, log): self.all_oids.extend(tag_oids) def add_profile_tag(self, profile_name): + # type: (str) -> None self.tags.append('snmp_profile:{}'.format(profile_name)) @staticmethod - def create_snmp_engine(mibs_path): + def create_snmp_engine(mibs_path=None): + # type: (str) -> Tuple[SnmpEngine, MibViewController] """ Create a command generator to perform all the snmp query. If mibs_path is not None, load the mibs present in the custom mibs folder. (Need to be in pysnmp format) """ - snmp_engine = hlapi.SnmpEngine() + snmp_engine = SnmpEngine() mib_builder = snmp_engine.getMibBuilder() if mibs_path is not None: - mib_builder.addMibSources(builder.DirMibSource(mibs_path)) + mib_builder.addMibSources(DirMibSource(mibs_path)) - mib_view_controller = view.MibViewController(mib_builder) + mib_view_controller = MibViewController(mib_builder) return snmp_engine, mib_view_controller @@ -219,7 +233,7 @@ def get_transport_target(instance, timeout, retries): """ ip_address = instance['ip_address'] port = int(instance.get('port', 161)) # Default SNMP port - return hlapi.UdpTransportTarget((ip_address, port), timeout=timeout, retries=retries) + return UdpTransportTarget((ip_address, port), timeout=timeout, retries=retries) @staticmethod def get_auth_data(instance): @@ -232,8 +246,8 @@ def get_auth_data(instance): # SNMP v1 - SNMP v2 # See http://snmplabs.com/pysnmp/docs/api-reference.html#pysnmp.hlapi.CommunityData if int(instance.get('snmp_version', 2)) == 1: - return hlapi.CommunityData(instance['community_string'], mpModel=0) - return hlapi.CommunityData(instance['community_string'], mpModel=1) + return CommunityData(instance['community_string'], mpModel=0) + return CommunityData(instance['community_string'], mpModel=1) if 'user' in instance: # SNMP v3 @@ -245,12 +259,12 @@ def get_auth_data(instance): if 'authKey' in instance: auth_key = instance['authKey'] - auth_protocol = hlapi.usmHMACMD5AuthProtocol + auth_protocol = usmHMACMD5AuthProtocol if 'privKey' in instance: priv_key = instance['privKey'] - auth_protocol = hlapi.usmHMACMD5AuthProtocol - priv_protocol = hlapi.usmDESPrivProtocol + auth_protocol = usmHMACMD5AuthProtocol + priv_protocol = usmDESPrivProtocol if 'authProtocol' in instance: auth_protocol = getattr(hlapi, instance['authProtocol']) @@ -258,7 +272,7 @@ def get_auth_data(instance): if 'privProtocol' in instance: priv_protocol = getattr(hlapi, instance['privProtocol']) - return hlapi.UsmUserData(user, auth_key, priv_key, auth_protocol, priv_protocol) + return UsmUserData(user, auth_key, priv_key, auth_protocol, priv_protocol) raise ConfigurationError('An authentication method needs to be provided') @@ -304,34 +318,38 @@ def parse_metrics( metrics, # type: List[Dict[str, Any]] warning, # type: Callable[..., None] log, # type: Callable[..., None] + object_identity_factory=None, # type: Callable[..., ObjectIdentity] # For unit tests purposes. ): # type: (...) -> Tuple[list, list, List[Union[ParsedMetric, ParsedTableMetric]]] """Parse configuration and returns data to be used for SNMP queries. `oids` is a dictionnary of SNMP tables to symbols to query. """ + if object_identity_factory is None: + object_identity_factory = ObjectIdentity + table_oids = {} # type: Dict[Tuple[str, str], Tuple[Any, List[Any]]] parsed_metrics = [] # type: List[Union[ParsedMetric, ParsedTableMetric]] - def extract_symbol(mib, symbol): + def extract_symbol(mib, symbol): # type: ignore if isinstance(symbol, dict): symbol_oid = symbol['OID'] symbol = symbol['name'] self._resolver.register(to_oid_tuple(symbol_oid), symbol) - identity = hlapi.ObjectIdentity(symbol_oid) + identity = object_identity_factory(symbol_oid) else: - identity = hlapi.ObjectIdentity(mib, symbol) + identity = object_identity_factory(mib, symbol) return identity, symbol - def get_table_symbols(mib, table): + def get_table_symbols(mib, table): # type: ignore identity, table = extract_symbol(mib, table) key = (mib, table) if key in table_oids: return table_oids[key][1], table - table_object = hlapi.ObjectType(identity) + table_object = ObjectType(identity) symbols = [] table_oids[key] = (table_object, symbols) @@ -382,7 +400,7 @@ def get_table_symbols(mib, table): column_tags.append((tag_key, column)) try: - object_type = hlapi.ObjectType(identity) + object_type = ObjectType(identity) except Exception as e: warning("Can't generate MIB object for variable : %s\nException: %s", metric, e) else: @@ -417,7 +435,7 @@ def get_table_symbols(mib, table): identity, parsed_metric_name = extract_symbol(metric['MIB'], symbol) try: - symbols.append(hlapi.ObjectType(identity)) + symbols.append(ObjectType(identity)) except Exception as e: warning("Can't generate MIB object for variable : %s\nException: %s", metric, e) @@ -425,7 +443,7 @@ def get_table_symbols(mib, table): parsed_metrics.append(parsed_table_metric) elif 'OID' in metric: - oid_object = hlapi.ObjectType(hlapi.ObjectIdentity(metric['OID'])) + oid_object = ObjectType(object_identity_factory(metric['OID'])) table_oids[metric['OID']] = (oid_object, []) self._resolver.register(to_oid_tuple(metric['OID']), metric['name']) @@ -467,12 +485,12 @@ def parse_metric_tags(self, metric_tags): tag_name = tag['tag'] if 'MIB' in tag: mib = tag['MIB'] - identity = hlapi.ObjectIdentity(mib, symbol) + identity = ObjectIdentity(mib, symbol) else: oid = tag['OID'] - identity = hlapi.ObjectIdentity(oid) + identity = ObjectIdentity(oid) self._resolver.register(to_oid_tuple(oid), symbol) - object_type = hlapi.ObjectType(identity) + object_type = ObjectType(identity) oids.append(object_type) parsed_metric_tags.append(ParsedMetricTag(tag_name, symbol)) return oids, parsed_metric_tags @@ -483,7 +501,7 @@ def add_uptime_metric(self): return # Reference sysUpTimeInstance directly, see http://oidref.com/1.3.6.1.2.1.1.3.0 uptime_oid = '1.3.6.1.2.1.1.3.0' - oid_object = hlapi.ObjectType(hlapi.ObjectIdentity(uptime_oid)) + oid_object = ObjectType(ObjectIdentity(uptime_oid)) self.all_oids.append(oid_object) self._resolver.register(to_oid_tuple(uptime_oid), 'sysUpTimeInstance') diff --git a/snmp/datadog_checks/snmp/exceptions.py b/snmp/datadog_checks/snmp/exceptions.py new file mode 100644 index 0000000000000..53c5d8b6cb505 --- /dev/null +++ b/snmp/datadog_checks/snmp/exceptions.py @@ -0,0 +1,11 @@ +# (C) Datadog, Inc. 2020-present +# All rights reserved +# Licensed under Simplified BSD License (see LICENSE) +""" +Re-export PySNMP exceptions that we use, so that we can access them from a single module. +""" + +from pysnmp.error import PySnmpError +from pysnmp.smi.error import SmiError + +__all__ = ['PySnmpError', 'SmiError'] diff --git a/snmp/datadog_checks/snmp/models.py b/snmp/datadog_checks/snmp/models.py new file mode 100644 index 0000000000000..dc1ffcb7c60db --- /dev/null +++ b/snmp/datadog_checks/snmp/models.py @@ -0,0 +1,66 @@ +# (C) Datadog, Inc. 2020-present +# All rights reserved +# Licensed under Simplified BSD License (see LICENSE) +""" +Re-export PyASN1/PySNMP types and classes that we use, so that we can access them from a single module. +""" + +from pyasn1.type.base import Asn1Type +from pyasn1.type.univ import OctetString +from pysnmp import hlapi +from pysnmp.hlapi import ( + CommunityData, + ContextData, + ObjectIdentity, + ObjectType, + SnmpEngine, + UdpTransportTarget, + UsmUserData, + usmDESPrivProtocol, + usmHMACMD5AuthProtocol, +) +from pysnmp.hlapi.asyncore.cmdgen import lcd +from pysnmp.hlapi.transport import AbstractTransportTarget +from pysnmp.proto.rfc1902 import Counter32, Counter64, Gauge32, Integer, Integer32, ObjectName, Unsigned32 +from pysnmp.smi.builder import DirMibSource, MibBuilder +from pysnmp.smi.exval import endOfMibView, noSuchInstance, noSuchObject +from pysnmp.smi.view import MibViewController + +# Additional types that are not part of the SNMP protocol (see RFC 2856). +CounterBasedGauge64, ZeroBasedCounter64 = MibBuilder().importSymbols( + 'HCNUM-TC', 'CounterBasedGauge64', 'ZeroBasedCounter64' +) + +# Cleanup. +del MibBuilder + +__all__ = [ + 'AbstractTransportTarget', + 'Asn1Type', + 'DirMibSource', + 'CommunityData', + 'ContextData', + 'CounterBasedGauge64', + 'endOfMibView', + 'hlapi', + 'lcd', + 'MibViewController', + 'noSuchInstance', + 'noSuchObject', + 'ObjectIdentity', + 'ObjectName', + 'ObjectType', + 'OctetString', + 'SnmpEngine', + 'UdpTransportTarget', + 'usmDESPrivProtocol', + 'usmHMACMD5AuthProtocol', + 'UsmUserData', + 'ZeroBasedCounter64', + 'Counter32', + 'Counter64', + 'Gauge32', + 'Unsigned32', + 'Integer', + 'Integer32', +] diff --git a/snmp/datadog_checks/snmp/resolver.py b/snmp/datadog_checks/snmp/resolver.py index 307de9f2c19ec..ef70b1009f829 100644 --- a/snmp/datadog_checks/snmp/resolver.py +++ b/snmp/datadog_checks/snmp/resolver.py @@ -4,7 +4,7 @@ from collections import defaultdict -from pysnmp import hlapi +from .models import ObjectIdentity class OIDTreeNode(object): @@ -87,7 +87,7 @@ def resolve_oid(self, oid): # if enforce_constraints is false, then MIB resolution has not been done yet # so we need to do it manually. We have to specify the mibs that we will need # to resolve the name. - oid_to_resolve = hlapi.ObjectIdentity(oid_tuple) + oid_to_resolve = ObjectIdentity(oid_tuple) result_oid = oid_to_resolve.resolveWithMib(self._mib_view_controller) _, metric, indexes = result_oid.getMibSymbol() return metric, tuple(index.prettyPrint() for index in indexes) diff --git a/snmp/datadog_checks/snmp/snmp.py b/snmp/datadog_checks/snmp/snmp.py index 3b1f2ced0814a..c8442e524f3e4 100644 --- a/snmp/datadog_checks/snmp/snmp.py +++ b/snmp/datadog_checks/snmp/snmp.py @@ -12,12 +12,7 @@ from concurrent import futures from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union -import pysnmp.proto.rfc1902 as snmp_type -from pyasn1.codec.ber import decoder -from pysnmp import hlapi -from pysnmp.error import PySnmpError -from pysnmp.smi import builder -from pysnmp.smi.exval import noSuchInstance, noSuchObject +from pyasn1.codec.ber.decoder import decode as pyasn1_decode from six import iteritems from datadog_checks.base import AgentCheck, ConfigurationError, is_affirmative @@ -26,30 +21,35 @@ from .commands import snmp_bulk, snmp_get, snmp_getnext from .compat import read_persistent_cache, total_time_to_temporal_percent, write_persistent_cache from .config import InstanceConfig, ParsedMetric, ParsedMetricTag, ParsedTableMetric -from .utils import OIDPrinter, get_profile_definition, oid_pattern_specificity, recursively_expand_base_profiles - -# Additional types that are not part of the SNMP protocol. cf RFC 2856 -CounterBasedGauge64, ZeroBasedCounter64 = builder.MibBuilder().importSymbols( - 'HCNUM-TC', 'CounterBasedGauge64', 'ZeroBasedCounter64' +from .exceptions import PySnmpError +from .models import ( + Counter32, + Counter64, + CounterBasedGauge64, + Gauge32, + Integer, + Integer32, + ObjectIdentity, + ObjectType, + Unsigned32, + ZeroBasedCounter64, + noSuchInstance, + noSuchObject, ) +from .utils import OIDPrinter, get_profile_definition, oid_pattern_specificity, recursively_expand_base_profiles # Metric type that we support -SNMP_COUNTERS = frozenset([snmp_type.Counter32.__name__, snmp_type.Counter64.__name__, ZeroBasedCounter64.__name__]) +SNMP_COUNTERS = frozenset([Counter32.__name__, Counter64.__name__, ZeroBasedCounter64.__name__]) SNMP_GAUGES = frozenset( - [ - snmp_type.Gauge32.__name__, - snmp_type.Unsigned32.__name__, - CounterBasedGauge64.__name__, - snmp_type.Integer.__name__, - snmp_type.Integer32.__name__, - ] + [Gauge32.__name__, Unsigned32.__name__, CounterBasedGauge64.__name__, Integer.__name__, Integer32.__name__] ) DEFAULT_OID_BATCH_SIZE = 10 def reply_invalid(oid): + # type: (Any) -> bool return noSuchInstance.isSameTypeWith(oid) or noSuchObject.isSameTypeWith(oid) @@ -236,7 +236,7 @@ def fetch_oids(self, config, oids, enforce_constraints): result_oid, value = var if reply_invalid(value): oid_tuple = result_oid.asTuple() - missing_results.append(hlapi.ObjectType(hlapi.ObjectIdentity(oid_tuple))) + missing_results.append(ObjectType(ObjectIdentity(oid_tuple))) else: all_binds.append(var) @@ -246,11 +246,13 @@ def fetch_oids(self, config, oids, enforce_constraints): self.log.debug( 'Running SNMP command getNext on OIDS: %s', OIDPrinter(missing_results, with_values=False) ) - binds = snmp_getnext( - config, - missing_results, - lookup_mib=enforce_constraints, - ignore_nonincreasing_oid=self.ignore_nonincreasing_oid, + binds = list( + snmp_getnext( + config, + missing_results, + lookup_mib=enforce_constraints, + ignore_nonincreasing_oid=self.ignore_nonincreasing_oid, + ) ) self.log.debug('Returned vars: %s', OIDPrinter(binds, with_values=True)) all_binds.extend(binds) @@ -270,7 +272,7 @@ def fetch_sysobject_oid(self, config): # type: (InstanceConfig) -> str """Return the sysObjectID of the instance.""" # Reference sysObjectID directly, see http://oidref.com/1.3.6.1.2.1.1.2 - oid = hlapi.ObjectType(hlapi.ObjectIdentity((1, 3, 6, 1, 2, 1, 1, 2, 0))) + oid = ObjectType(ObjectIdentity((1, 3, 6, 1, 2, 1, 1, 2, 0))) self.log.debug('Running SNMP command on OID: %r', OIDPrinter((oid,), with_values=False)) var_binds = snmp_get(config, [oid], lookup_mib=False) self.log.debug('Returned vars: %s', OIDPrinter(var_binds, with_values=True)) @@ -327,9 +329,13 @@ def check(self, instance): if self._thread is None: self._start_discovery() + executor = self._executor + if executor is None: + raise RuntimeError("Expected executor be set") + sent = [] for host, discovered in list(config.discovered_instances.items()): - future = self._executor.submit(self._check_with_config, discovered) + future = executor.submit(self._check_with_config, discovered) sent.append(future) future.add_done_callback(functools.partial(self._check_config_done, host)) futures.wait(sent) @@ -341,6 +347,7 @@ def check(self, instance): self._check_with_config(config) def _check_config_done(self, host, future): + # type: (str, futures.Future) -> None config = self._config if future.result(): config.failing_instances[host] += 1 @@ -535,7 +542,7 @@ def submit_metric(self, name, snmp_value, forced_type, tags): if snmp_class == 'Opaque': # Try support for floats try: - value = float(decoder.decode(bytes(snmp_value))[0]) + value = float(pyasn1_decode(bytes(snmp_value))[0]) except Exception: pass else: diff --git a/snmp/datadog_checks/snmp/utils.py b/snmp/datadog_checks/snmp/utils.py index 6a20dd2e228b4..2ff826dbb1612 100644 --- a/snmp/datadog_checks/snmp/utils.py +++ b/snmp/datadog_checks/snmp/utils.py @@ -2,15 +2,13 @@ # All rights reserved # Licensed under Simplified BSD License (see LICENSE) import os -from typing import Any, Dict, Tuple +from typing import Any, Dict, Mapping, Sequence, Tuple, Union import yaml -from pysnmp import hlapi -from pysnmp.proto.rfc1902 import ObjectName -from pysnmp.smi.error import SmiError -from pysnmp.smi.exval import endOfMibView, noSuchInstance from .compat import get_config +from .exceptions import SmiError +from .models import ObjectName, ObjectType, endOfMibView, noSuchInstance def get_profile_definition(profile): @@ -37,6 +35,7 @@ def get_profile_definition(profile): def _get_profiles_root(): + # type: () -> str # NOTE: this separate helper function exists for mocking purposes. confd = get_config('confd_path') return os.path.join(confd, 'snmp.d', 'profiles') @@ -109,11 +108,12 @@ class OIDPrinter(object): """ def __init__(self, oids, with_values): + # type: (Union[Mapping, Sequence], bool) -> None self.oids = oids self.with_values = with_values def oid_str(self, oid): - # type: (hlapi.ObjectType) -> str + # type: (ObjectType) -> str """Display an OID object (or MIB symbol), even if the object is not initialized by PySNMP. Output: @@ -131,7 +131,7 @@ def oid_str(self, oid): return arg def oid_str_value(self, oid): - # type: (hlapi.ObjectType) -> str + # type: (ObjectType) -> str """Display an OID object and its associated value. Output: @@ -181,6 +181,7 @@ def oid_dict(self, key, value): return "'{}': {}".format(key, displayed) def __str__(self): + # type: () -> str if isinstance(self.oids, dict): return '{{{}}}'.format(', '.join(self.oid_dict(key, value) for (key, value) in self.oids.items())) if self.with_values: diff --git a/snmp/tests/test_unit.py b/snmp/tests/test_unit.py index 26fedb286df86..f14f9b3670e43 100644 --- a/snmp/tests/test_unit.py +++ b/snmp/tests/test_unit.py @@ -5,7 +5,7 @@ import os import time from concurrent import futures -from typing import List +from typing import Any, List import mock import pytest @@ -15,38 +15,43 @@ from datadog_checks.dev import temp_dir from datadog_checks.snmp import SnmpCheck from datadog_checks.snmp.config import InstanceConfig +from datadog_checks.snmp.models import ObjectIdentity from datadog_checks.snmp.resolver import OIDTrie from datadog_checks.snmp.utils import oid_pattern_specificity, recursively_expand_base_profiles from . import common -from .utils import mock_profiles_root +from .utils import ClassInstantiationSpy, mock_profiles_root pytestmark = pytest.mark.unit -@mock.patch("datadog_checks.snmp.config.hlapi") @mock.patch("datadog_checks.snmp.config.lcd") -def test_parse_metrics(lcd_mock, hlapi_mock): +def test_parse_metrics(lcd_mock): + # type: (Any) -> None lcd_mock.configure.return_value = ('addr', None) instance = common.generate_instance_config(common.SUPPORTED_METRIC_TYPES) check = SnmpCheck('snmp', {}, [instance]) # Unsupported metric - metrics = [{"foo": "bar"}] + metrics = [{"foo": "bar"}] # type: list config = InstanceConfig( {"ip_address": "127.0.0.1", "community_string": "public", "metrics": [{"OID": "1.2.3", "name": "foo"}]}, warning=check.warning, log=check.log, ) - hlapi_mock.reset_mock() + + object_identity_factory = ClassInstantiationSpy(ObjectIdentity) + with pytest.raises(Exception): config.parse_metrics(metrics, check.warning, check.log) # Simple OID metrics = [{"OID": "1.2.3", "name": "foo"}] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 1 - hlapi_mock.ObjectIdentity.assert_called_once_with("1.2.3") - hlapi_mock.reset_mock() + object_identity_factory.assert_called_once_with("1.2.3") + object_identity_factory.reset() # MIB with no symbol or table metrics = [{"MIB": "foo_mib"}] @@ -55,10 +60,12 @@ def test_parse_metrics(lcd_mock, hlapi_mock): # MIB with symbol metrics = [{"MIB": "foo_mib", "symbol": "foo"}] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 1 - hlapi_mock.ObjectIdentity.assert_called_once_with("foo_mib", "foo") - hlapi_mock.reset_mock() + object_identity_factory.assert_called_once_with("foo_mib", "foo") + object_identity_factory.reset() # MIB with table, no symbols metrics = [{"MIB": "foo_mib", "table": "foo"}] @@ -67,11 +74,13 @@ def test_parse_metrics(lcd_mock, hlapi_mock): # MIB with table and symbols metrics = [{"MIB": "foo_mib", "table": "foo", "symbols": ["foo", "bar"]}] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 2 - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "foo") - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "bar") - hlapi_mock.reset_mock() + object_identity_factory.assert_any_call("foo_mib", "foo") + object_identity_factory.assert_any_call("foo_mib", "bar") + object_identity_factory.reset() # MIB with table, symbols, bad metrics_tags metrics = [{"MIB": "foo_mib", "table": "foo", "symbols": ["foo", "bar"], "metric_tags": [{}]}] @@ -85,47 +94,55 @@ def test_parse_metrics(lcd_mock, hlapi_mock): # Table with manual OID metrics = [{"MIB": "foo_mib", "table": "foo", "symbols": [{"OID": "1.2.3", "name": "foo"}]}] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 1 - hlapi_mock.ObjectIdentity.assert_any_call("1.2.3") - hlapi_mock.reset_mock() + object_identity_factory.assert_any_call("1.2.3") + object_identity_factory.reset() # MIB with table, symbols, metrics_tags index metrics = [ {"MIB": "foo_mib", "table": "foo", "symbols": ["foo", "bar"], "metric_tags": [{"tag": "foo", "index": "1"}]} ] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 2 - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "foo") - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "bar") - hlapi_mock.reset_mock() + object_identity_factory.assert_any_call("foo_mib", "foo") + object_identity_factory.assert_any_call("foo_mib", "bar") + object_identity_factory.reset() # MIB with table, symbols, metrics_tags column metrics = [ {"MIB": "foo_mib", "table": "foo", "symbols": ["foo", "bar"], "metric_tags": [{"tag": "foo", "column": "baz"}]} ] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 3 - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "foo") - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "bar") - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "baz") - hlapi_mock.reset_mock() + object_identity_factory.assert_any_call("foo_mib", "foo") + object_identity_factory.assert_any_call("foo_mib", "bar") + object_identity_factory.assert_any_call("foo_mib", "baz") + object_identity_factory.reset() # MIB with table, symbols, metrics_tags column with OID metrics = [ { "MIB": "foo_mib", - "table": "foo", + "table": "foo_table", "symbols": ["foo", "bar"], "metric_tags": [{"tag": "foo", "column": {"name": "baz", "OID": "1.5.6"}}], } ] - table, _, _ = config.parse_metrics(metrics, check.warning, check.log) + table, _, _ = config.parse_metrics( + metrics, check.warning, check.log, object_identity_factory=object_identity_factory + ) assert len(table) == 3 - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "foo") - hlapi_mock.ObjectIdentity.assert_any_call("foo_mib", "bar") - hlapi_mock.ObjectIdentity.assert_any_call("1.5.6") - hlapi_mock.reset_mock() + object_identity_factory.assert_any_call("1.5.6") + object_identity_factory.assert_any_call("foo_mib", "foo") + object_identity_factory.assert_any_call("foo_mib", "bar") + object_identity_factory.reset() def test_ignore_ip_addresses(): diff --git a/snmp/tests/utils.py b/snmp/tests/utils.py index 7c56cbf75d72a..286f7be710368 100644 --- a/snmp/tests/utils.py +++ b/snmp/tests/utils.py @@ -3,15 +3,45 @@ # Licensed under Simplified BSD License (see LICENSE) import contextlib -from typing import Iterator +from typing import Any, Generic, Iterator, List, Type, TypeVar import mock from datadog_checks.snmp import utils +T = TypeVar("T") + @contextlib.contextmanager def mock_profiles_root(root): # type: (str) -> Iterator[None] with mock.patch.object(utils, '_get_profiles_root', return_value=root): yield + + +class ClassInstantiationSpy(Generic[T]): + """ + Record instantiations of a class. + """ + + def __init__(self, cls): + # type: (Type[T]) -> None + self.cls = cls + self.calls = [] # type: List[tuple] + + def __call__(self, *args): + # type: (*Any) -> T + self.calls.append(args) + return self.cls(*args) + + def assert_called_once_with(self, *args): + # type: (*Any) -> None + assert self.calls.count(args) == 1 + + def assert_any_call(self, *args): + # type: (*Any) -> None + assert args in self.calls + + def reset(self): + # type: () -> None + self.calls = []