diff --git a/tests/dvslib/dvs_common.py b/tests/dvslib/dvs_common.py index ce7db9859666..ba8ef8c26b50 100644 --- a/tests/dvslib/dvs_common.py +++ b/tests/dvslib/dvs_common.py @@ -1,47 +1,41 @@ -""" - dvs_common contains common infrastructure for writing tests for the - virtual switch. -""" +"""Common infrastructure for writing VS tests.""" import collections import time +from typing import Any, Callable, Tuple + _PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict') + class PollingConfig(_PollingConfig): - """ - PollingConfig provides parameters that are used to control the behavior - for polling functions. - - Params: - polling_interval (int): How often to poll, in seconds. - timeout (int): The maximum amount of time to wait, in seconds. - strict (bool): If the strict flag is set, reaching the timeout - will cause tests to fail (e.g. assert False) + """PollingConfig provides parameters that are used to control polling behavior. + + Attributes: + polling_interval (int): How often to poll, in seconds. + timeout (int): The maximum amount of time to wait, in seconds. + strict (bool): If the strict flag is set, reaching the timeout will cause tests to fail. """ - pass -def wait_for_result(polling_function, polling_config): - """ - wait_for_result will periodically run `polling_function` - using the parameters described in `polling_config` and return the - output of the polling function. - - Args: - polling_config (PollingConfig): The parameters to use to poll - the db. - polling_function (Callable[[], (bool, Any)]): The function being - polled. The function takes no arguments and must return a - status which indicates if the function was succesful or - not, as well as some return value. - - Returns: - (bool, Any): If the polling function succeeds, then this method - will return True and the output of the polling function. If it - does not succeed within the provided timeout, it will return False - and whatever the output of the polling function was on the final - attempt. +def wait_for_result( + polling_function: Callable[[], Tuple[bool, Any]], + polling_config: PollingConfig, +) -> Tuple[bool, Any]: + """Run `polling_function` periodically using the specified `polling_config`. + + Args: + polling_function: The function being polled. The function cannot take any arguments and + must return a status which indicates if the function was succesful or not, as well as + some return value. + polling_config: The parameters to use to poll the polling function. + + Returns: + If the polling function succeeds, then this method will return True and the output of the + polling function. + + If it does not succeed within the provided timeout, it will return False and whatever the + output of the polling function was on the final attempt. """ if polling_config.polling_interval == 0: iterations = 1 @@ -57,6 +51,6 @@ def wait_for_result(polling_function, polling_config): time.sleep(polling_config.polling_interval) if polling_config.strict: - assert False, "Operation timed out after {}s".format(polling_config.timeout) + assert False, f"Operation timed out after {polling_config.timeout} seconds" return (False, result) diff --git a/tests/dvslib/dvs_database.py b/tests/dvslib/dvs_database.py index b76196a96a05..6d5829f02a62 100644 --- a/tests/dvslib/dvs_database.py +++ b/tests/dvslib/dvs_database.py @@ -1,76 +1,62 @@ -""" - dvs_database contains utilities for interacting with redis when writing - tests for the virtual switch. -""" +"""Utilities for interacting with redis when writing VS tests.""" +from typing import Dict, List from swsscommon import swsscommon from dvslib.dvs_common import wait_for_result, PollingConfig -class DVSDatabase(object): - """ - DVSDatabase provides access to redis databases on the virtual switch. +class DVSDatabase: + """DVSDatabase provides access to redis databases on the virtual switch. - By default, database operations are configured to use - `DEFAULT_POLLING_CONFIG`. Users can specify their own PollingConfig, - but this shouldn't typically be necessary. + By default, database operations are configured to use `DEFAULT_POLLING_CONFIG`. Users can + specify their own PollingConfig, but this shouldn't typically be necessary. """ + DEFAULT_POLLING_CONFIG = PollingConfig(polling_interval=0.01, timeout=5, strict=True) - def __init__(self, db_id, connector): - """ - Initializes a DVSDatabase instance. + def __init__(self, db_id: int, connector: str): + """Initialize a DVSDatabase instance. - Args: - db_id (int): The integer ID used to identify the given database - instance in redis. - connector (str): The I/O connection used to communicate with - redis (e.g. unix socket, tcp socket, etc.). + Args: + db_id: The integer ID used to identify the given database instance in redis. + connector: The I/O connection used to communicate with + redis (e.g. UNIX socket, TCP socket, etc.). """ - self.db_connection = swsscommon.DBConnector(db_id, connector, 0) - def create_entry(self, table_name, key, entry): - """ - Adds the mapping {`key` -> `entry`} to the specified table. + def create_entry(self, table_name: str, key: str, entry: Dict[str, str]) -> None: + """Add the mapping {`key` -> `entry`} to the specified table. - Args: - table_name (str): The name of the table to add the entry to. - key (str): The key that maps to the entry. - entry (Dict[str, str]): A set of key-value pairs to be stored. + Args: + table_name: The name of the table to add the entry to. + key: The key that maps to the entry. + entry: A set of key-value pairs to be stored. """ - table = swsscommon.Table(self.db_connection, table_name) formatted_entry = swsscommon.FieldValuePairs(list(entry.items())) table.set(key, formatted_entry) - def update_entry(self, table_name, key, entry): - """ - Updates entries of an existing key in the specified table. + def update_entry(self, table_name: str, key: str, entry: Dict[str, str]) -> None: + """Update entry of an existing key in the specified table. - Args: - table_name (str): The name of the table. - key (str): The key that needs to be updated. - entry (Dict[str, str]): A set of key-value pairs to be updated. + Args: + table_name: The name of the table. + key: The key that needs to be updated. + entry: A set of key-value pairs to be updated. """ - table = swsscommon.Table(self.db_connection, table_name) formatted_entry = swsscommon.FieldValuePairs(list(entry.items())) table.set(key, formatted_entry) - def get_entry(self, table_name, key): - """ - Gets the entry stored at `key` in the specified table. + def get_entry(self, table_name: str, key: str) -> Dict[str, str]: + """Get the entry stored at `key` in the specified table. - Args: - table_name (str): The name of the table where the entry is - stored. - key (str): The key that maps to the entry being retrieved. + Args: + table_name: The name of the table where the entry is stored. + key: The key that maps to the entry being retrieved. - Returns: - Dict[str, str]: The entry stored at `key`. If no entry is found, - then an empty Dict will be returned. + Returns: + The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - table = swsscommon.Table(self.db_connection, table_name) (status, fv_pairs) = table.get(key) @@ -79,284 +65,258 @@ def get_entry(self, table_name, key): return dict(fv_pairs) - def delete_entry(self, table_name, key): - """ - Removes the entry stored at `key` in the specified table. + def delete_entry(self, table_name: str, key: str) -> None: + """Remove the entry stored at `key` in the specified table. - Args: - table_name (str): The name of the table where the entry is - being removed. - key (str): The key that maps to the entry being removed. + Args: + table_name: The name of the table where the entry is being removed. + key: The key that maps to the entry being removed. """ - table = swsscommon.Table(self.db_connection, table_name) table._del(key) # pylint: disable=protected-access - def get_keys(self, table_name): - """ - Gets all of the keys stored in the specified table. + def get_keys(self, table_name: str) -> List[str]: + """Get all of the keys stored in the specified table. - Args: - table_name (str): The name of the table from which to fetch - the keys. + Args: + table_name: The name of the table from which to fetch the keys. - Returns: - List[str]: The keys stored in the table. If no keys are found, - then an empty List will be returned. + Returns: + The keys stored in the table. If no keys are found, then an empty List is returned. """ - table = swsscommon.Table(self.db_connection, table_name) keys = table.getKeys() return keys if keys else [] - def wait_for_entry(self, table_name, key, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Gets the entry stored at `key` in the specified table. This method - will wait for the entry to exist. - - Args: - table_name (str): The name of the table where the entry is - stored. - key (str): The key that maps to the entry being retrieved. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - Dict[str, str]: The entry stored at `key`. If no entry is found, - then an empty Dict will be returned. - + def wait_for_entry( + self, + table_name: str, + key: str, + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> Dict[str, str]: + """Wait for the entry stored at `key` in the specified table to exist and retrieve it. + + Args: + table_name: The name of the table where the entry is stored. + key: The key that maps to the entry being retrieved. + polling_config: The parameters to use to poll the db. + + Returns: + The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - - def _access_function(): + def __access_function(): fv_pairs = self.get_entry(table_name, key) return (bool(fv_pairs), fv_pairs) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Entry not found: key=\"{}\", table=\"{}\"".format(key, table_name) + f"Entry not found: key=\"{key}\", table=\"{table_name}\"" return result - def wait_for_field_match(self, - table_name, - key, - expected_fields, - polling_config=DEFAULT_POLLING_CONFIG): + def wait_for_field_match( + self, + table_name: str, + key: str, + expected_fields: Dict[str, str], + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> Dict[str, str]: + """Wait for the entry stored at `key` to have the specified fields and retrieve it. + + This method is useful if you only care about a subset of the fields stored in the + specified entry. + + Args: + table_name: The name of the table where the entry is stored. + key: The key that maps to the entry being checked. + expected_fields: The fields and their values we expect to see in the entry. + polling_config: The parameters to use to poll the db. + + Returns: + The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - Checks if the provided fields are contained in the entry stored - at `key` in the specified table. This method will wait for the - fields to exist. - - Args: - table_name (str): The name of the table where the entry is - stored. - key (str): The key that maps to the entry being checked. - expected_fields (dict): The fields and their values we expect - to see in the entry. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - Dict[str, str]: The entry stored at `key`. If no entry is found, - then an empty Dict will be returned. - """ - - def _access_function(): + def __access_function(): fv_pairs = self.get_entry(table_name, key) return (all(fv_pairs.get(k) == v for k, v in expected_fields.items()), fv_pairs) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Expected fields not found: expected={}, received={}, \ - key=\"{}\", table=\"{}\"".format(expected_fields, result, key, table_name) + f"Expected fields not found: expected={expected_fields}, \ + received={result}, key=\"{key}\", table=\"{table_name}\"" return result - def wait_for_exact_match(self, - table_name, - key, - expected_entry, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Checks if the provided entry matches the entry stored at `key` - in the specified table. This method will wait for the exact entry - to exist. - - Args: - table_name (str): The name of the table where the entry is - stored. - key (str): The key that maps to the entry being checked. - expected_entry (dict): The entry we expect to see. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - Dict[str, str]: The entry stored at `key`. If no entry is found, - then an empty Dict will be returned. + def wait_for_exact_match( + self, + table_name: str, + key: str, + expected_entry: Dict[str, str], + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> Dict[str, str]: + """Wait for the entry stored at `key` to match `expected_entry` and retrieve it. + + This method is useful if you care about *all* the fields stored in the specfied entry. + + Args: + table_name: The name of the table where the entry is stored. + key: The key that maps to the entry being checked. + expected_entry: The entry we expect to see. + polling_config: The parameters to use to poll the db. + + Returns: + The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def _access_function(): + def __access_function(): fv_pairs = self.get_entry(table_name, key) return (fv_pairs == expected_entry, fv_pairs) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Exact match not found: expected={}, received={}, \ - key=\"{}\", table=\"{}\"".format(expected_entry, result, key, table_name) + f"Exact match not found: expected={expected_entry}, received={result}, \ + key=\"{key}\", table=\"{table_name}\"" return result - def wait_for_deleted_entry(self, - table_name, - key, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Checks if there is any entry stored at `key` in the specified - table. This method will wait for the entry to be empty. - - Args: - table_name (str): The name of the table being checked. - key (str): The key to be checked. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - Dict[str, str]: The entry stored at `key`. If no entry is found, - then an empty Dict will be returned. + def wait_for_deleted_entry( + self, + table_name: str, + key: str, + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> Dict[str, str]: + """Wait for no entry to exist at `key` in the specified table. + + Args: + table_name: The name of the table being checked. + key: The key to be checked. + polling_config: The parameters to use to poll the db. + + Returns: + The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - - def _access_function(): + def __access_function(): fv_pairs = self.get_entry(table_name, key) return (not bool(fv_pairs), fv_pairs) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Entry still exists: entry={}, key=\"{}\", table=\"{}\""\ - .format(result, key, table_name) + f"Entry still exists: entry={result}, key=\"{key}\", table=\"{table_name}\"" return result - def wait_for_n_keys(self, - table_name, - num_keys, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Gets all of the keys stored in the specified table. This method - will wait for the specified number of keys. - - Args: - table_name (str): The name of the table from which to fetch - the keys. - num_keys (int): The expected number of keys to retrieve from - the table. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - List[str]: The keys stored in the table. If no keys are found, - then an empty List will be returned. + def wait_for_n_keys( + self, + table_name: str, + num_keys: int, + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> List[str]: + """Wait for the specified number of keys to exist in the table. + + Args: + table_name: The name of the table from which to fetch the keys. + num_keys: The expected number of keys to retrieve from the table. + polling_config: The parameters to use to poll the db. + + Returns: + The keys stored in the table. If no keys are found, then an empty List is returned. """ - - def _access_function(): + def __access_function(): keys = self.get_keys(table_name) return (len(keys) == num_keys, keys) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Unexpected number of keys: expected={}, received={} ({}), table=\"{}\""\ - .format(num_keys, len(result), result, table_name) + f"Unexpected number of keys: expected={num_keys}, \ + received={len(result)} ({result}), table=\"{table_name}\"" return result - def wait_for_matching_keys(self, - table_name, - expected_keys, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Checks if the specified keys exist in the table. This method - will wait for the keys to exist. - - Args: - table_name (str): The name of the table from which to fetch - the keys. - expected_keys (List[str]): The keys we expect to see in the - table. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - List[str]: The keys stored in the table. If no keys are found, - then an empty List will be returned. + def wait_for_matching_keys( + self, + table_name: str, + expected_keys: List[str], + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> List[str]: + """Wait for the specified keys to exist in the table. + + Args: + table_name: The name of the table from which to fetch the keys. + expected_keys: The keys we expect to see in the table. + polling_config: The parameters to use to poll the db. + + Returns: + The keys stored in the table. If no keys are found, then an empty List is returned. """ - - def _access_function(): + def __access_function(): keys = self.get_keys(table_name) return (all(key in keys for key in expected_keys), keys) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: assert not polling_config.strict, \ - "Expected keys not found: expected={}, received={}, table=\"{}\""\ - .format(expected_keys, result, table_name) + f"Expected keys not found: expected={expected_keys}, received={result}, \ + table=\"{table_name}\"" return result - def wait_for_deleted_keys(self, - table_name, - deleted_keys, - polling_config=DEFAULT_POLLING_CONFIG): - """ - Checks if the specified keys no longer exist in the table. This - method will wait for the keys to be deleted. - - Args: - table_name (str): The name of the table from which to fetch - the keys. - deleted_keys (List[str]): The keys we expect to be removed from - the table. - polling_config (PollingConfig): The parameters to use to poll - the db. - - Returns: - List[str]: The keys stored in the table. If no keys are found, - then an empty List will be returned. + def wait_for_deleted_keys( + self, + table_name: str, + deleted_keys: List[str], + polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + ) -> List[str]: + """Wait for the specfied keys to no longer exist in the table. + + Args: + table_name: The name of the table from which to fetch the keys. + deleted_keys: The keys we expect to be removed from the table. + polling_config: The parameters to use to poll the db. + + Returns: + The keys stored in the table. If no keys are found, then an empty List is returned. """ - - def _access_function(): + def __access_function(): keys = self.get_keys(table_name) return (all(key not in keys for key in deleted_keys), keys) - status, result = wait_for_result(_access_function, - self._disable_strict_polling(polling_config)) + status, result = wait_for_result( + __access_function, + self._disable_strict_polling(polling_config)) if not status: + expected = [key for key in result if key not in deleted_keys] assert not polling_config.strict, \ - "Unexpected keys found: expected={}, received={}, table=\"{}\""\ - .format(deleted_keys, result, table_name) + f"Unexpected keys found: expected={expected}, received={result}, \ + table=\"{table_name}\"" return result @staticmethod - def _disable_strict_polling(polling_config): + def _disable_strict_polling(polling_config: PollingConfig) -> PollingConfig: disabled_config = PollingConfig(polling_interval=polling_config.polling_interval, timeout=polling_config.timeout, strict=False)