Skip to content

Commit

Permalink
Merge pull request #31 from ThomasGerstenberg/integrated-tests
Browse files Browse the repository at this point in the history
adds tests for scanning, minor updates to test infrastructure
  • Loading branch information
ThomasGerstenberg authored May 2, 2020
2 parents 889437f + 269e8cf commit bfb4d4c
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 42 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ with v4.1.1 of [pc-ble-driver](https://github.com/NordicSemiconductor/pc-ble-dri
- [X] License
- [ ] Unit Tests
- [ ] Integration Tests
- In progress. Advertising and Scanning mostly done


The library aims to support both event-driven and procedural program styles. It takes similar paradigms from C#/.NET's event function signatures,
Expand Down
43 changes: 35 additions & 8 deletions tests/integrated/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import logging
import time
from typing import Optional
from unittest import TestCase, SkipTest
from unittest.util import safe_repr
from functools import wraps

from blatann import BleDevice
Expand Down Expand Up @@ -46,6 +48,7 @@ class BlatannTestCase(TestCase):
def setUpClass(cls) -> None:
if BlatannTestCase.logger is None:
BlatannTestCase.logger = setup_logger()
cls.logger = logging.getLogger(cls.__module__)

cls.dev1 = _configure_device(1, cls.dev1_config)
cls.dev2 = _configure_device(2, cls.dev2_config)
Expand All @@ -68,26 +71,50 @@ def tearDownClass(cls) -> None:
cls.dev2.close()
if cls.dev3:
cls.dev3.close()
# Wait some time for the devices to close and the device to reset
# The nRF52 dev kits don't need this, but the nrf52840 USB dongles seem to need a 2s delay.
# Guessing this is due to the Comport being USB-CDC and during the reset the USB device is not enumerated
# versus the dev kits where the port persists across MCU reboots since it is routed through the on-board J-Link
time.sleep(2)

def assertDeltaWithin(self, expected_value, actual_value, acceptable_delta, message=""):
actual_delta = abs(expected_value - actual_value)
self.logger.debug(f"Delta: {actual_delta:.3f}, Acceptable: {acceptable_delta:.3f}")
if not actual_delta <= acceptable_delta:
standard_msg = "%s is not within %s +- %s" % (safe_repr(actual_value), safe_repr(expected_value),
safe_repr(acceptable_delta))
self.fail(self._formatMessage(message, standard_msg))


class TestParams(object):
def __init__(self, test_params, setup=None, teardown=None):
def __init__(self, test_params, setup=None, teardown=None, long_running_params=None):
self.test_params = test_params
self.long_running_params = long_running_params or []
self._setup = setup
self._teardown = teardown

def __call__(self, func):
@wraps(func)
def subtest_runner(test_case: BlatannTestCase):
tc_name = f"{test_case.__class__.__name__}.{func.__name__}"

quick_tests = int(os.environ.get(BLATANN_QUICK_ENVKEY, 0))
test_params = self.test_params + self.long_running_params
n_tests_to_run = len(self.test_params) if quick_tests else len(test_params)

try:
self.setup(test_case)

for tc in self.test_params:
with test_case.subTest(**tc):
param_s = ", ".join("{}={!r}".format(k, v) for k, v in tc.items())
test_case.logger.info(
"Running {}.{}({})".format(self.__class__.__name__, func.__name__, param_s))
func(test_case, **tc)
for i, params in enumerate(test_params):
param_str = ", ".join(f"{k}={repr(v)}" for k, v in params.items())
subtest_str = f"{tc_name}({param_str})"

with test_case.subTest(**params):
if i < n_tests_to_run:
test_case.logger.info(f"Running {subtest_str}")
func(test_case, **params)
else:
test_case.skipTest(f"Skipping {subtest_str} because it's a long-running test")
finally:
self.teardown(test_case)

Expand All @@ -112,7 +139,7 @@ def teardown(self, test_case: BlatannTestCase):

def long_running(func):
@wraps(func)
def f(self: TestCase, *args, **kwargs):
def f(self: BlatannTestCase, *args, **kwargs):
quick_tests = int(os.environ.get(BLATANN_QUICK_ENVKEY, 0))
if quick_tests:
name = "{}.{}".format(self.__class__.__name__, func.__name__)
Expand Down
20 changes: 2 additions & 18 deletions tests/integrated/test_advertising_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,12 @@ def test_advertising_scan_data(self):
self.assertEqual(self.default_adv_data.local_name, adv_data.local_name)
self.assertEqual(self.default_scan_data.service_uuid16s, adv_data.service_uuid16s)

def test_non_active_scanning_no_scan_response_packets_received(self):
self.dev1.advertiser.set_advertise_data(self.default_adv_data, self.default_scan_data)
self.dev1.advertiser.start()
self._configure_scan(active_scan=False)
self.dev2.scanner.set_default_scan_params(100, 100, 5, active_scanning=False)
results = self.dev2.scanner.start_scan(clear_scan_reports=True).wait(10)

# Get the list of all advertising packets from the advertiser
all_packets, adv_packets, scan_response_packets = self._get_packets_for_adv(results)
self.assertGreater(len(all_packets), 0)
self.assertEqual(len(all_packets), len(adv_packets))
self.assertEqual(0, len(scan_response_packets))

for p in adv_packets:
self.assertEqual(self.default_adv_data_bytes, p.raw_bytes)

def test_non_connectable_undirected_no_scan_response_packets_received(self):
self._configure_adv(adv_mode=AdvertisingMode.non_connectable_undirected)
self.dev1.advertiser.set_advertise_data(self.default_adv_data, self.default_scan_data)
self.dev1.advertiser.start()
self.dev2.scanner.set_default_scan_params(100, 100, 5, active_scanning=False)
results = self.dev2.scanner.start_scan(clear_scan_reports=True).wait(10)
self.dev2.scanner.set_default_scan_params(100, 100, 5, active_scanning=True)
results = self.dev2.scanner.start_scan().wait(10)

# Get the list of all advertising packets from the advertiser
all_packets, adv_packets, scan_response_packets = self._get_packets_for_adv(results)
Expand Down
28 changes: 12 additions & 16 deletions tests/integrated/test_advertising_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from tests.integrated.base import BlatannTestCase, TestParams, long_running


# TODO: The acceptable duration deltas are generous because the nRF52 dev kits (being UART) are slower
# than the nRF52840 dongles by roughly an order of magnitude (1M baud UART vs. USB-CDC)


class TestAdvertisingDuration(BlatannTestCase):
def setUp(self) -> None:
self.adv_interval_ms = 50
Expand All @@ -23,8 +27,8 @@ def tearDown(self) -> None:
self.dev1.advertiser.stop()
self.dev2.scanner.stop()

@long_running
@TestParams([dict(duration=x) for x in [1, 4, 8, 10]])
@TestParams([dict(duration=x) for x in [1, 4, 10]], long_running_params=
[dict(duration=x) for x in [120, 180]])
def test_advertise_duration(self, duration):
acceptable_delta = 0.100
acceptable_delta_scan = 1.000
Expand All @@ -49,15 +53,11 @@ def on_scan_report(scanner: Scanner, report: ScanReport):
self.assertFalse(wait_stopwatch.is_running)
self.assertFalse(self.dev1.advertiser.is_advertising)

wait_delta = abs(duration - wait_stopwatch.elapsed)
self.assertLessEqual(wait_delta, acceptable_delta)

scan_delta = abs(duration - scan_stopwatch.elapsed)
self.assertLessEqual(scan_delta, acceptable_delta_scan)

self.logger.info("Wait Delta: {:.3f}, Scan Delta: {:.3f}".format(wait_delta, scan_delta))
self.assertDeltaWithin(duration, wait_stopwatch.elapsed, acceptable_delta)
self.assertDeltaWithin(duration, scan_stopwatch.elapsed, acceptable_delta_scan)

@TestParams([dict(duration=x) for x in [1, 2, 4]])
@TestParams([dict(duration=x) for x in [1, 2, 4, 10]], long_running_params=
[dict(duration=x) for x in [30, 60]])
def test_advertise_duration_timeout_event(self, duration):
acceptable_delta = 0.100
on_timeout_event = threading.Event()
Expand All @@ -73,9 +73,7 @@ def on_timeout(*args, **kwargs):
self.assertTrue(on_timeout_event.is_set())
self.assertFalse(self.dev1.advertiser.is_advertising)

actual_delta = abs(duration - stopwatch.elapsed)
self.assertLessEqual(actual_delta, acceptable_delta)
self.logger.info("Delta: {:.3f}".format(actual_delta))
self.assertDeltaWithin(duration, stopwatch.elapsed, acceptable_delta)

def test_advertise_auto_restart(self):
# Scan for longer than the advertising duration,
Expand All @@ -100,9 +98,7 @@ def test_advertise_auto_restart(self):

report_seen_duration = report_timestamps[-1] - report_timestamps[0]

delta = abs(report_seen_duration - scan_duration)
self.assertLessEqual(delta, acceptable_delta)
self.logger.info("Delta: {:.3f}".format(delta))
self.assertDeltaWithin(scan_duration, report_seen_duration, acceptable_delta)


if __name__ == '__main__':
Expand Down
93 changes: 93 additions & 0 deletions tests/integrated/test_scanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import threading
import unittest

from blatann.gap.advertising import AdvertisingMode
from blatann.gap.advertise_data import AdvertisingData, AdvertisingFlags, AdvertisingPacketType
from blatann.gap.scanning import MIN_SCAN_WINDOW_MS, MIN_SCAN_INTERVAL_MS, ScanParameters
from blatann.uuid import Uuid16
from blatann.utils import Stopwatch

from tests.integrated.base import BlatannTestCase, TestParams, long_running


class TestScanner(BlatannTestCase):
def setUp(self) -> None:
self.adv_interval_ms = 20
self.adv_mac_addr = self.dev1.address
self.adv_mode = AdvertisingMode.scanable_undirected
self.scan_params = ScanParameters(MIN_SCAN_INTERVAL_MS, MIN_SCAN_WINDOW_MS, 4)
self.flags = AdvertisingFlags.GENERAL_DISCOVERY_MODE | AdvertisingFlags.BR_EDR_NOT_SUPPORTED
self.uuid16s = [Uuid16(0xABCD), Uuid16(0xDEF0)]
self.default_adv_data = AdvertisingData(flags=self.flags, local_name="Blatann Test")
self.default_adv_data_bytes = self.default_adv_data.to_bytes()
self.default_scan_data = AdvertisingData(service_uuid16s=self.uuid16s)
self.default_scan_data_bytes = self.default_scan_data.to_bytes()

def tearDown(self) -> None:
self.dev1.advertiser.stop()
self.dev2.scanner.stop()

def _get_packets_for_adv(self, results):
all_packets = [p for p in results.all_scan_reports if p.peer_address == self.adv_mac_addr]
adv_packets = [p for p in all_packets if p.packet_type == self.adv_mode]
scan_response_packets = [p for p in all_packets if p.packet_type == AdvertisingPacketType.scan_response]
return all_packets, adv_packets, scan_response_packets

@TestParams([dict(duration=x) for x in [1, 2, 4, 10]], long_running_params=
[dict(duration=x) for x in [60, 120]])
def test_scan_duration(self, duration):
acceptable_delta = 0.100
on_timeout_event = threading.Event()
self.scan_params.timeout_s = duration

self.dev1.advertiser.start(self.adv_interval_ms, duration+2)

def on_timeout(*args, **kwargs):
on_timeout_event.set()

with self.dev2.scanner.on_scan_timeout.register(on_timeout):
with Stopwatch() as stopwatch:
self.dev2.scanner.start_scan(self.scan_params)
on_timeout_event.wait(duration + 2)

self.assertTrue(on_timeout_event.is_set())
self.assertFalse(self.dev2.scanner.is_scanning)

actual_delta = abs(duration - stopwatch.elapsed)
self.assertLessEqual(actual_delta, acceptable_delta)
self.logger.info("Delta: {:.3f}".format(actual_delta))

def test_scan_iterator(self):
acceptable_delta = 0.100
self.scan_params.timeout_s = 5

self.dev1.advertiser.start(self.adv_interval_ms, self.scan_params.timeout_s+2)

adv_address = self.dev1.address
report_count_from_advertiser = 0
with Stopwatch() as stopwatch:
for report in self.dev2.scanner.start_scan(self.scan_params).scan_reports:
if report.peer_address == adv_address:
report_count_from_advertiser += 1

self.assertGreater(report_count_from_advertiser, 0)
self.assertDeltaWithin(self.scan_params.timeout_s, stopwatch.elapsed, acceptable_delta)

def test_non_active_scanning_no_scan_response_packets_received(self):
self.dev1.advertiser.set_advertise_data(self.default_adv_data, self.default_scan_data)
self.dev1.advertiser.start(advertise_mode=self.adv_mode)
self.scan_params.active = False
results = self.dev2.scanner.start_scan(self.scan_params).wait(10)

# Get the list of all advertising packets from the advertiser
all_packets, adv_packets, scan_response_packets = self._get_packets_for_adv(results)
self.assertGreater(len(all_packets), 0)
self.assertEqual(len(all_packets), len(adv_packets))
self.assertEqual(0, len(scan_response_packets))

for p in adv_packets:
self.assertEqual(self.default_adv_data_bytes, p.raw_bytes)


if __name__ == '__main__':
unittest.main()

0 comments on commit bfb4d4c

Please sign in to comment.