From df7548749926a58c8fd7be6d7c1f49ae3145be64 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 30 Apr 2020 00:24:00 -0700 Subject: [PATCH] adds param validation to adv/scan/conn intervals, adds type annotations, fixes couple potential issues --- blatann/examples/centeral_uart_service.py | 2 +- blatann/examples/peripheral_uart_service.py | 8 +- blatann/gap/advertising.py | 36 +++++- blatann/gap/scanning.py | 48 +++++-- blatann/gatt/service_discovery.py | 2 +- blatann/nrf/nrf_types/gap.py | 48 +++++++ blatann/peer.py | 136 ++++++++++++++++---- blatann/waitables/connection_waitable.py | 2 +- blatann/waitables/event_waitable.py | 2 +- blatann/waitables/waitable.py | 28 +++- 10 files changed, 261 insertions(+), 51 deletions(-) diff --git a/blatann/examples/centeral_uart_service.py b/blatann/examples/centeral_uart_service.py index 610b62a..6d727ce 100644 --- a/blatann/examples/centeral_uart_service.py +++ b/blatann/examples/centeral_uart_service.py @@ -76,7 +76,7 @@ def main(serial_port): # Set scan duration for 4 seconds ble_device.scanner.set_default_scan_params(timeout_seconds=4) - ble_device.set_default_peripheral_connection_params(10, 30, 4000) + ble_device.set_default_peripheral_connection_params(7.5, 15, 4000) logger.info("Scanning for peripherals advertising UUID {}".format(nordic_uart.NORDIC_UART_SERVICE_UUID)) target_address = None diff --git a/blatann/examples/peripheral_uart_service.py b/blatann/examples/peripheral_uart_service.py index e538a3a..370741d 100644 --- a/blatann/examples/peripheral_uart_service.py +++ b/blatann/examples/peripheral_uart_service.py @@ -25,7 +25,7 @@ def on_connect(peer, event_args): :param event_args: None """ if peer: - logger.info("Connected to peer") + logger.info("Connected to peer, initiating MTU exchange") peer.exchange_mtu() else: logger.warning("Connection timed out") @@ -52,6 +52,8 @@ def on_mtu_size_update(peer, event_args): :type event_args: blatann.event_args.MtuSizeUpdatedEventArgs """ logger.info("MTU size updated from {} to {}".format(event_args.previous_mtu_size, event_args.current_mtu_size)) + # Request that the connection parameters be re-negotiated using our preferred parameters + peer.update_connection_parameters() def on_data_rx(service, data): @@ -92,7 +94,7 @@ def main(serial_port): # Configure the client to prefer the max MTU size ble_device.client.preferred_mtu_size = ble_device.max_mtu_size - ble_device.client.set_connection_parameters(10, 30, 4000) + ble_device.client.set_connection_parameters(7.5, 15, 4000) # Advertise the service UUID adv_data = advertising.AdvertisingData(flags=0x06, local_name="Nordic UART Server") @@ -115,4 +117,4 @@ def main(serial_port): if __name__ == '__main__': - main("COM8") + main("COM7") diff --git a/blatann/gap/advertising.py b/blatann/gap/advertising.py index 7b7a08a..1ce5fef 100644 --- a/blatann/gap/advertising.py +++ b/blatann/gap/advertising.py @@ -11,6 +11,9 @@ AdvertisingMode = nrf_types.BLEGapAdvType +MIN_ADVERTISING_INTERVAL_MS = nrf_types.adv_interval_range.min +MAX_ADVERTISING_INTERVAL_MS = nrf_types.adv_interval_range.max + class Advertiser(object): # Constant used to indicate that the BLE device should advertise indefinitely, until @@ -48,8 +51,28 @@ def on_advertising_timeout(self): @property def is_advertising(self): + """ + Current state of advertising + :return: + """ return self._is_advertising + @property + def min_interval_ms(self) -> float: + """ + The minimum allowed advertising interval, in millseconds. + This is defined by the Bluetooth specification. + """ + return MIN_ADVERTISING_INTERVAL_MS + + @property + def max_interval_ms(self) -> float: + """ + The maximum allowed advertising interval, in milliseconds. + This is defined by the Bluetooth specification. + """ + return MAX_ADVERTISING_INTERVAL_MS + def set_advertise_data(self, advertise_data=AdvertisingData(), scan_response=AdvertisingData()): """ Sets the advertising and scan response data which will be broadcasted to peers during advertising @@ -79,11 +102,13 @@ def set_default_advertise_params(self, advertise_interval_ms, timeout_seconds, a """ Sets the default advertising parameters so they do not need to be specified on each start - :param advertise_interval_ms: The advertising interval, in milliseconds - :param timeout_seconds: How long to advertise for before timing out, in seconds + :param advertise_interval_ms: The advertising interval, in milliseconds. + Should be a multiple of 0.625ms, otherwise it'll be rounded down to the nearest 0.625ms + :param timeout_seconds: How long to advertise for before timing out, in seconds. For no timeout, use ADVERTISE_FOREVER (0) :param advertise_mode: The mode the advertiser should use :type advertise_mode: AdvertisingMode """ + nrf_types.adv_interval_range.validate(advertise_interval_ms) self._advertise_interval = advertise_interval_ms self._timeout = timeout_seconds self._advertise_mode = advertise_mode @@ -92,8 +117,9 @@ def start(self, adv_interval_ms=None, timeout_sec=None, auto_restart=None, adver """ Starts advertising with the given parameters. If none given, will use the default - :param adv_interval_ms: The interval at which to send out advertise packets, in milliseconds - :param timeout_sec: The duration which to advertise for + :param adv_interval_ms: The interval at which to send out advertise packets, in milliseconds. + Should be a multiple of 0.625ms, otherwise it'll be rounde down to the nearest 0.625ms + :param timeout_sec: The duration which to advertise for. For no timeout, use ADVERTISE_FOREVER (0) :param auto_restart: Flag indicating that advertising should restart automatically when the timeout expires, or when the client disconnects :param advertise_mode: The mode the advertiser should use @@ -105,6 +131,8 @@ def start(self, adv_interval_ms=None, timeout_sec=None, auto_restart=None, adver self._stop() if adv_interval_ms is None: adv_interval_ms = self._advertise_interval + else: + nrf_types.adv_interval_range.validate(adv_interval_ms) if timeout_sec is None: timeout_sec = self._timeout if advertise_mode is None: diff --git a/blatann/gap/scanning.py b/blatann/gap/scanning.py index 1a58618..87c1337 100644 --- a/blatann/gap/scanning.py +++ b/blatann/gap/scanning.py @@ -8,8 +8,34 @@ logger = logging.getLogger(__name__) +MIN_SCAN_INTERVAL_MS = nrf_types.scan_interval_range.min +MAX_SCAN_INTERVAL_MS = nrf_types.scan_interval_range.max +MIN_SCAN_WINDOW_MS = nrf_types.scan_window_range.min +MAX_SCAN_WINDOW_MS = nrf_types.scan_window_range.max +MIN_SCAN_TIMEOUT_S = nrf_types.scan_timeout_range.min +MAX_SCAN_TIMEOUT_S = nrf_types.scan_timeout_range.max + + class ScanParameters(nrf_types.BLEGapScanParams): - pass + def validate(self): + self._validate(self.window_ms, self.interval_ms, self.timeout_s) + + def update(self, window_ms, interval_ms, timeout_s, active): + self._validate(window_ms, interval_ms, timeout_s) + self.window_ms = window_ms + self.interval_ms = interval_ms + self.timeout_s = timeout_s, + self.active = active + + def _validate(self, window_ms, interval_ms, timeout_s): + # Check against absolute limits + nrf_types.scan_window_range.validate(window_ms) + nrf_types.scan_interval_range.validate(interval_ms) + if timeout_s: + nrf_types.scan_timeout_range.validate(timeout_s) + # Verify that the window is not larger than the interval + if window_ms > interval_ms: + raise ValueError(f"Window cannot be greater than the interval (window: {window_ms}, interval: {interval_ms}") class Scanner(object): @@ -19,7 +45,7 @@ def __init__(self, ble_device): """ self.ble_device = ble_device self._default_scan_params = ScanParameters(200, 150, 10) - self.scanning = False + self._is_scanning = False ble_device.ble_driver.event_subscribe(self._on_adv_report, nrf_events.GapEvtAdvReport) ble_device.ble_driver.event_subscribe(self._on_timeout_event, nrf_events.GapEvtTimeout) self.scan_report = ScanReportCollection() @@ -42,6 +68,10 @@ def on_scan_timeout(self) -> Event[Scanner, ScanReportCollection]: """ return self._on_scan_timeout + @property + def is_scanning(self) -> bool: + return self._is_scanning + def set_default_scan_params(self, interval_ms=200, window_ms=150, timeout_seconds=10, active_scanning=True): """ Sets the default scan parameters so they do not have to be specified each time a scan is started. @@ -53,10 +83,7 @@ def set_default_scan_params(self, interval_ms=200, window_ms=150, timeout_second :param timeout_seconds: How long to advertise for, in seconds :param active_scanning: Whether or not to fetch scan response packets from advertisers """ - self._default_scan_params.interval_ms = interval_ms - self._default_scan_params.window_ms = window_ms - self._default_scan_params.timeout_s = timeout_seconds - self._default_scan_params.active = active_scanning + self._default_scan_params.update(window_ms, interval_ms, timeout_seconds, active_scanning) def start_scan(self, scan_parameters=None, clear_scan_reports=True): """ @@ -74,15 +101,18 @@ def start_scan(self, scan_parameters=None, clear_scan_reports=True): self.scan_report = ScanReportCollection() if not scan_parameters: scan_parameters = self._default_scan_params + else: + # Make sure the scan parameters are valid + scan_parameters.validate() self.ble_device.ble_driver.ble_gap_scan_start(scan_parameters) - self.scanning = True + self._is_scanning = True return scan_waitable.ScanFinishedWaitable(self.ble_device) def stop(self): """ Stops an active scan """ - self.scanning = False + self._is_scanning = False try: self.ble_device.ble_driver.ble_gap_scan_stop() @@ -98,5 +128,5 @@ def _on_timeout_event(self, driver, event): :type event: nrf_events.GapEvtTimeout """ if event.src == nrf_events.BLEGapTimeoutSrc.scan: - self.scanning = False + self._is_scanning = False self._on_scan_timeout.notify(self.ble_device, self.scan_report) diff --git a/blatann/gatt/service_discovery.py b/blatann/gatt/service_discovery.py index 1ef02e6..fd57067 100644 --- a/blatann/gatt/service_discovery.py +++ b/blatann/gatt/service_discovery.py @@ -355,7 +355,7 @@ def __init__(self, ble_device, peer): @property def on_discovery_complete(self): """ - :rtype: Event + :rtype: Event[blatann.peer.Peripheral, DatabaseDiscoveryCompleteEventArgs] """ return self._on_discovery_complete diff --git a/blatann/nrf/nrf_types/gap.py b/blatann/nrf/nrf_types/gap.py index 142c0ff..582b0bf 100644 --- a/blatann/nrf/nrf_types/gap.py +++ b/blatann/nrf/nrf_types/gap.py @@ -9,6 +9,54 @@ logger = logging.getLogger(__name__) +class TimeRange(object): + + def __init__(self, name, val_min, val_max, unit_ms_conversion, divisor=1.0, units="ms"): + self._name = name + self._units = units + self._min = util.units_to_msec(val_min, unit_ms_conversion) / divisor + self._max = util.units_to_msec(val_max, unit_ms_conversion) / divisor + + @property + def name(self) -> str: + return self._name + + @property + def min(self) -> float: + return self._min + + @property + def max(self) -> float: + return self._max + + @property + def units(self) -> str: + return self._units + + def is_in_range(self, value): + return self._min <= value <= self._max + + def validate(self, value): + if value < self._min: + raise ValueError(f"Minimum {self.name} is {self._min}{self.units} (Got {value})") + if value > self._max: + raise ValueError(f"Maximum {self.name} is {self._max}{self.units} (Got {value})") + + +adv_interval_range = TimeRange("Advertising Interval", + driver.BLE_GAP_ADV_INTERVAL_MIN, driver.BLE_GAP_ADV_INTERVAL_MAX, util.UNIT_0_625_MS) +scan_window_range = TimeRange("Scan Window", + driver.BLE_GAP_SCAN_WINDOW_MIN, driver.BLE_GAP_SCAN_WINDOW_MAX, util.UNIT_0_625_MS) +scan_interval_range = TimeRange("Scan Interval", + driver.BLE_GAP_SCAN_INTERVAL_MIN, driver.BLE_GAP_SCAN_INTERVAL_MAX, util.UNIT_0_625_MS) +scan_timeout_range = TimeRange("Scan Timeout", + driver.BLE_GAP_SCAN_TIMEOUT_MIN, driver.BLE_GAP_SCAN_TIMEOUT_MAX, util.UNIT_10_MS, 1000.0, "s") +conn_interval_range = TimeRange("Connection Interval", + driver.BLE_GAP_CP_MIN_CONN_INTVL_MIN, driver.BLE_GAP_CP_MAX_CONN_INTVL_MAX, util.UNIT_1_25_MS) +conn_timeout_range = TimeRange("Connection Timeout", + driver.BLE_GAP_CP_CONN_SUP_TIMEOUT_MIN, driver.BLE_GAP_CP_CONN_SUP_TIMEOUT_MAX, util.UNIT_10_MS) + + class BLEGapAdvParams(object): def __init__(self, interval_ms, timeout_s, advertising_type=BLEGapAdvType.connectable_undirected): self.interval_ms = interval_ms diff --git a/blatann/peer.py b/blatann/peer.py index b23b5bb..a4c3bc1 100644 --- a/blatann/peer.py +++ b/blatann/peer.py @@ -8,7 +8,10 @@ from blatann.gatt import gattc, service_discovery, MTU_SIZE_DEFAULT, MTU_SIZE_MINIMUM from blatann.nrf import nrf_events from blatann.nrf.nrf_types.enums import BLE_CONN_HANDLE_INVALID -from blatann.waitables import connection_waitable, event_waitable +from blatann.nrf.nrf_types import conn_interval_range, conn_timeout_range +from blatann.waitables.waitable import EmptyWaitable +from blatann.waitables.connection_waitable import DisconnectionWaitable +from blatann.waitables.event_waitable import EventWaitable from blatann.event_args import * logger = logging.getLogger(__name__) @@ -25,9 +28,62 @@ class PeerAddress(nrf_events.BLEGapAddr): class ConnectionParameters(nrf_events.BLEGapConnParams): + """ + Represents the connection parameters that are sent during negotiation. This includes + the preferred min/max interval range, timeout, and slave latency + """ def __init__(self, min_conn_interval_ms, max_conn_interval_ms, timeout_ms, slave_latency=0): # TODO: Parameter validation super(ConnectionParameters, self).__init__(min_conn_interval_ms, max_conn_interval_ms, timeout_ms, slave_latency) + self.validate() + + def validate(self): + conn_interval_range.validate(self.min_conn_interval_ms) + conn_interval_range.validate(self.max_conn_interval_ms) + conn_timeout_range.validate(self.conn_sup_timeout_ms) + if self.min_conn_interval_ms > self.max_conn_interval_ms: + raise ValueError(f"Minimum connection interval must be <= max connection interval " + f"(Min: {self.min_conn_interval_ms} Max: {self.max_conn_interval_ms}") + + +class ActiveConnectionParameters(object): + """ + Represents the connection parameters that are currently in use with a peer device. + This is similar to ConnectionParameters with the sole difference being + the connection interval is not a min/max range but a single number + """ + def __init__(self, conn_params: ConnectionParameters): + self._interval_ms = conn_params.min_conn_interval_ms + self._timeout_ms = conn_params.conn_sup_timeout_ms + self._slave_latency = conn_params.slave_latency + + def __repr__(self): + return str(self) + + def __str__(self): + return f"ConnectionParams({self._interval_ms}ms/{self._slave_latency}/{self._timeout_ms}ms)" + + @property + def interval_ms(self) -> float: + """ + The connection interval, in milliseconds + """ + return self._interval_ms + + @property + def timeout_ms(self) -> float: + """ + The connection timeout, in milliseconds + """ + return self._timeout_ms + + @property + def slave_latency(self) -> int: + """ + The slave latency (the number of connection intervals the slave is allowed to skip before being + required to respond) + """ + return self._slave_latency DEFAULT_CONNECTION_PARAMS = ConnectionParameters(15, 30, 4000, 0) @@ -50,8 +106,8 @@ def __init__(self, ble_device, role, connection_params=DEFAULT_CONNECTION_PARAMS """ self._ble_device = ble_device self._role = role - self._ideal_connection_params = connection_params - self._current_connection_params = DEFAULT_CONNECTION_PARAMS + self._preferred_connection_params = connection_params + self._current_connection_params = ActiveConnectionParameters(connection_params) self.conn_handle = BLE_CONN_HANDLE_INVALID self.peer_address = "", self.connection_state = PeerState.DISCONNECTED @@ -62,6 +118,8 @@ def __init__(self, ble_device, role, connection_params=DEFAULT_CONNECTION_PARAMS self._mtu_size = MTU_SIZE_DEFAULT self._preferred_mtu_size = MTU_SIZE_DEFAULT self._negotiated_mtu_size = None + self._disconnection_reason = nrf_events.BLEHci.local_host_terminated_connection + self._connection_based_driver_event_handlers = {} self._connection_handler_lock = threading.Lock() @@ -110,6 +168,21 @@ def is_previously_bonded(self): """ return self.security.is_previously_bonded + @property + def preferred_connection_params(self) -> ConnectionParameters: + """ + Returns the connection parameters that were negotiated for this peer + """ + return self._preferred_connection_params + + @property + def active_connection_params(self) -> ActiveConnectionParameters: + """ + Gets the active connection parameters in use with the peer. + If the peer is disconnected, this will return the connection parameters last used + """ + return self._current_connection_params + @property def mtu_size(self): """ @@ -189,38 +262,42 @@ def on_mtu_size_updated(self) -> Event[Peer, MtuSizeUpdatedEventArgs]: Public Methods """ - def disconnect(self, status_code=nrf_events.BLEHci.remote_user_terminated_connection): + def disconnect(self, status_code=nrf_events.BLEHci.remote_user_terminated_connection) -> DisconnectionWaitable: """ Disconnects from the peer, giving the optional status code. - Returns a waitable that will fire when the disconnection is complete + Returns a waitable that will fire when the disconnection is complete. + If the peer is already disconnected, the waitable will fire immediately :param status_code: The HCI Status code to send back to the peer :return: A waitable that will fire when the peer is disconnected - :rtype: connection_waitable.DisconnectionWaitable """ if self.connection_state != PeerState.CONNECTED: - return + return EmptyWaitable(self, self._disconnection_reason) self._ble_device.ble_driver.ble_gap_disconnect(self.conn_handle, status_code) return self._disconnect_waitable def set_connection_parameters(self, min_connection_interval_ms, max_connection_interval_ms, connection_timeout_ms, slave_latency=0): """ - Sets the connection parameters for the peer and starts the connection parameter update process + Sets the connection parameters for the peer and starts the connection parameter update process (if connected) :param min_connection_interval_ms: The minimum acceptable connection interval, in milliseconds :param max_connection_interval_ms: The maximum acceptable connection interval, in milliseconds :param connection_timeout_ms: The connection timeout, in milliseconds :param slave_latency: The slave latency allowed """ - self._ideal_connection_params = ConnectionParameters(min_connection_interval_ms, max_connection_interval_ms, - connection_timeout_ms, slave_latency) - if not self.connected: - return - # Do stuff to set the connection parameters - self._ble_device.ble_driver.ble_gap_conn_param_update(self.conn_handle, self._ideal_connection_params) + self._preferred_connection_params = ConnectionParameters(min_connection_interval_ms, max_connection_interval_ms, + connection_timeout_ms, slave_latency) + if self.connected: + self.update_connection_parameters() + + def update_connection_parameters(self): + """ + Starts the process to re-negotiate the connection parameters using the previously-set connection parameters + """ + self._ble_device.ble_driver.ble_gap_conn_param_update(self.conn_handle, self._preferred_connection_params) - def exchange_mtu(self, mtu_size=None): + def exchange_mtu(self, mtu_size=None) -> EventWaitable[Peer, MtuSizeUpdatedEventArgs]: """ Initiates the MTU Exchange sequence with the peer device. @@ -229,7 +306,6 @@ def exchange_mtu(self, mtu_size=None): :param mtu_size: Optional MTU size to use. If provided, it will also updated the preferred MTU size :return: A waitable that will fire when the MTU exchange completes - :rtype: event_waitable.EventWaitable """ # If the MTU size has already been negotiated we need to use the same value # as the previous exchange (Vol 3, Part F 3.4.2.2) @@ -241,7 +317,7 @@ def exchange_mtu(self, mtu_size=None): self._negotiated_mtu_size = self.preferred_mtu_size self._ble_device.ble_driver.ble_gattc_exchange_mtu_req(self.conn_handle, self._negotiated_mtu_size) - return event_waitable.EventWaitable(self._on_mtu_exchange_complete) + return EventWaitable(self._on_mtu_exchange_complete) """ Internal Library Methods @@ -255,9 +331,9 @@ def peer_connected(self, conn_handle, peer_address, connection_params): self.peer_address = peer_address self._mtu_size = MTU_SIZE_DEFAULT self._negotiated_mtu_size = None - self._disconnect_waitable = connection_waitable.DisconnectionWaitable(self) + self._disconnect_waitable = DisconnectionWaitable(self) self.connection_state = PeerState.CONNECTED - self._current_connection_params = connection_params + self._current_connection_params = ActiveConnectionParameters(connection_params) self._ble_device.ble_driver.event_subscribe(self._on_disconnect_event, nrf_events.GapEvtDisconnected) self.driver_event_subscribe(self._on_connection_param_update, nrf_events.GapEvtConnParamUpdate, nrf_events.GapEvtConnParamUpdateRequest) @@ -318,6 +394,7 @@ def _on_disconnect_event(self, driver, event): return self.conn_handle = BLE_CONN_HANDLE_INVALID self.connection_state = PeerState.DISCONNECTED + self._disconnection_reason = event.reason self._on_disconnect.notify(self, DisconnectionEventArgs(event.reason)) with self._connection_handler_lock: @@ -334,11 +411,11 @@ def _on_connection_param_update(self, driver, event): if not self.connected or self.conn_handle != event.conn_handle: return if isinstance(event, nrf_events.GapEvtConnParamUpdateRequest): - logger.debug("[{}] Conn Params updating to {}".format(self.conn_handle, self._ideal_connection_params)) - self._ble_device.ble_driver.ble_gap_conn_param_update(self.conn_handle, self._ideal_connection_params) + logger.debug("[{}] Conn Params updating to {}".format(self.conn_handle, self._preferred_connection_params)) + self._ble_device.ble_driver.ble_gap_conn_param_update(self.conn_handle, self._preferred_connection_params) else: logger.debug("[{}] Updated to {}".format(self.conn_handle, event.conn_params)) - self._current_connection_params = event.conn_params + self._current_connection_params = ActiveConnectionParameters(event.conn_params) def _validate_mtu_size(self, mtu_size): if mtu_size < MTU_SIZE_MINIMUM: @@ -393,17 +470,23 @@ def __init__(self, ble_device, peer_address, connection_params=DEFAULT_CONNECTIO self._discoverer = service_discovery.DatabaseDiscoverer(ble_device, self) @property - def database(self): + def database(self) -> gattc.GattcDatabase: """ Gets the database on the peripheral. NOTE: This is not useful until services are discovered first :return: The database instance - :rtype: gattc.GattcDatabase """ return self._db - def discover_services(self): + @property + def on_database_discovery_complete(self) -> Event[Peripheral, DatabaseDiscoveryCompleteEventArgs]: + """ + Event that is triggered when database discovery has completed + """ + return self._discoverer.on_discovery_complete + + def discover_services(self) -> EventWaitable[Peripheral, DatabaseDiscoveryCompleteEventArgs]: """ Starts the database discovery process of the peripheral. This will discover all services, characteristics, and descriptors on the remote database. @@ -411,10 +494,9 @@ def discover_services(self): Waitable returns 2 parameters: (Peripheral this, DatabaseDiscoveryCompleteEventArgs event args) :return: a Waitable that will fire when service discovery is complete - :rtype: event_waitable.EventWaitable """ self._discoverer.start() - return event_waitable.EventWaitable(self._discoverer.on_discovery_complete) + return EventWaitable(self._discoverer.on_discovery_complete) class Client(Peer): diff --git a/blatann/waitables/connection_waitable.py b/blatann/waitables/connection_waitable.py index 4768635..697bdf3 100644 --- a/blatann/waitables/connection_waitable.py +++ b/blatann/waitables/connection_waitable.py @@ -79,7 +79,7 @@ def __init__(self, connected_peer): :type ble_device: blatann.device.BleDevice :type connected_peer: blatann.peer.Peer """ - super(DisconnectionWaitable, self).__init__() + super(DisconnectionWaitable, self).__init__(n_args=2) if not connected_peer: raise InvalidStateException("Peer already disconnected") connected_peer.on_disconnect.register(self._on_disconnect) diff --git a/blatann/waitables/event_waitable.py b/blatann/waitables/event_waitable.py index 879786b..1772208 100644 --- a/blatann/waitables/event_waitable.py +++ b/blatann/waitables/event_waitable.py @@ -11,7 +11,7 @@ def __init__(self, event): """ :type event: blatann.event_type.Event """ - super(EventWaitable, self).__init__() + super(EventWaitable, self).__init__(n_args=2) self._event = event self._event.register(self._on_event) diff --git a/blatann/waitables/waitable.py b/blatann/waitables/waitable.py index 2b82d6e..ebe1a66 100644 --- a/blatann/waitables/waitable.py +++ b/blatann/waitables/waitable.py @@ -3,9 +3,12 @@ class Waitable(object): - def __init__(self): + def __init__(self, n_args=1): self._queue = queue.Queue() self._callback = None + self._n_args = n_args + if n_args < 1: + raise ValueError() def wait(self, timeout=None, exception_on_timeout=True): try: @@ -18,9 +21,9 @@ def wait(self, timeout=None, exception_on_timeout=True): if exception_on_timeout: raise TimeoutError("Timed out waiting for event to occur. " "Waitable type: {}".format(self.__class__.__name__)) - # TODO: This will fail if the waitable implementation normally returns more than one value and - # the caller tries to unpack - return None + if self._n_args == 1: + return None + return [None] * self._n_args def then(self, func_to_execute): self._callback = func_to_execute @@ -38,3 +41,20 @@ def _notify(self, *results): class GenericWaitable(Waitable): def notify(self, *results): self._notify(*results) + + +class EmptyWaitable(Waitable): + """ + Waitable class which will immediately return the args provided when waited on + or when a callback function is registered + """ + def __init__(self, *args): + super(EmptyWaitable, self).__init__(len(args)) + self._args = args + + def wait(self, timeout=None, exception_on_timeout=True): + return self._args + + def then(self, func_to_execute): + func_to_execute(*self._args) + return self