From aa9df3da428e38694ab04763207a758a1dc3c308 Mon Sep 17 00:00:00 2001 From: Erik Jaegervall Date: Mon, 25 Sep 2023 12:38:21 +0200 Subject: [PATCH] Raise exception if not connected Also make sure that we have consistent check for available connection Fixes #667 --- kuksa-client/docs/examples/async-grpc.md | 8 +-- kuksa-client/docs/examples/sync-grpc.md | 4 +- kuksa-client/kuksa_client/grpc/__init__.py | 42 +++++++++------- kuksa-client/kuksa_client/grpc/aio.py | 57 ++++++++++++++++------ kuksa-client/tests/test_grpc.py | 10 ++++ 5 files changed, 85 insertions(+), 36 deletions(-) diff --git a/kuksa-client/docs/examples/async-grpc.md b/kuksa-client/docs/examples/async-grpc.md index fe47add84..a214786cb 100644 --- a/kuksa-client/docs/examples/async-grpc.md +++ b/kuksa-client/docs/examples/async-grpc.md @@ -37,7 +37,9 @@ async def main(): asyncio.run(main()) ``` -Besides this there is a solution where you are not using the client as context-manager +Besides this there is a solution where you are not using the client as context-manager. +Then you must explicitly call `connect()`. + ```python import asyncio @@ -85,7 +87,7 @@ async def main(): async for updates in client.subscribe_current_values([ 'Vehicle.Body.Windshield.Front.Wiping.System.TargetPosition', ]): - if current_values['Vehicle.Body.Windshield.Front.Wiping.System.TargetPosition'] is not None: + if updates['Vehicle.Body.Windshield.Front.Wiping.System.TargetPosition'] is not None: current_position = updates['Vehicle.Body.Windshield.Front.Wiping.System.TargetPosition'].value print(f"Current wiper position is: {current_position}") @@ -108,7 +110,7 @@ async def main(): current_values = await client.get_current_values([ 'Vehicle.Speed', ]) - + if current_values['Vehicle.Speed'] is not None: print(current_values['Vehicle.Speed'].value) diff --git a/kuksa-client/docs/examples/sync-grpc.md b/kuksa-client/docs/examples/sync-grpc.md index 2cb068f58..c6836861a 100644 --- a/kuksa-client/docs/examples/sync-grpc.md +++ b/kuksa-client/docs/examples/sync-grpc.md @@ -33,7 +33,9 @@ with VSSClient('127.0.0.1', 55555) as client: print(current_values['Vehicle.Speed'].value) ``` -Besides this there is a solution where you are not using the client as context-manager +Besides this there is a solution where you are not using the client as context-manager. +Then you must explicitly call `connect()`. + ```python from kuksa_client.grpc import VSSClient diff --git a/kuksa-client/kuksa_client/grpc/__init__.py b/kuksa-client/kuksa_client/grpc/__init__.py index fa1d24ede..7c8c887c9 100644 --- a/kuksa-client/kuksa_client/grpc/__init__.py +++ b/kuksa-client/kuksa_client/grpc/__init__.py @@ -559,7 +559,7 @@ def _load_creds(self) -> Optional[grpc.ChannelCredentials]: else: logger.info("No client certificates provided, mutual TLS not supported!") return grpc.ssl_channel_credentials(root_certificates) - logger.info("No Root CA present, it will not be posible to use a secure connection!") + logger.info("No Root CA present, it will not be possible to use a secure connection!") return None def _prepare_get_request(self, entries: Iterable[EntryRequest]) -> val_pb2.GetRequest: @@ -675,8 +675,9 @@ def wrapper(self, *args, **kwargs): if self.connected: return func(self, *args, **kwargs) else: - logger.info( - "Disconnected from server! Try connect.") + # This shall normally not happen if you use the client as context manager + # as then a connect will happen automatically when you enter the context + raise Exception("Server not connected! Call connect() before using this command!") return wrapper def connect(self, target_host=None): @@ -709,6 +710,7 @@ def disconnect(self): self.channel = None self.connected = False + @check_connected def get_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Datapoint]: """ Parameters: @@ -726,6 +728,7 @@ def get_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Da ) return {entry.path: entry.value for entry in entries} + @check_connected def get_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Datapoint]: """ Parameters: @@ -742,6 +745,7 @@ def get_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Dat ) for path in paths), **rpc_kwargs) return {entry.path: entry.actuator_target for entry in entries} + @check_connected def get_metadata( self, paths: Iterable[str], field: MetadataField = MetadataField.ALL, **rpc_kwargs, ) -> Dict[str, Metadata]: @@ -761,6 +765,7 @@ def get_metadata( ) return {entry.path: entry.metadata for entry in entries} + @check_connected def set_current_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> None: """ Parameters: @@ -778,6 +783,7 @@ def set_current_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> Non **rpc_kwargs, ) + @check_connected def set_target_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> None: """ Parameters: @@ -791,6 +797,7 @@ def set_target_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> None DataEntry(path, actuator_target=dp), (Field.ACTUATOR_TARGET,), ) for path, dp in updates.items()], **rpc_kwargs) + @check_connected def set_metadata( self, updates: Dict[str, Metadata], field: MetadataField = MetadataField.ALL, **rpc_kwargs, ) -> None: @@ -807,6 +814,7 @@ def set_metadata( DataEntry(path, metadata=md), (Field(field.value),), ) for path, md in updates.items()], **rpc_kwargs) + @check_connected def subscribe_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Iterator[Dict[str, Datapoint]]: """ Parameters: @@ -826,6 +834,7 @@ def subscribe_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Iterat ): yield {update.entry.path: update.entry.value for update in updates} + @check_connected def subscribe_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Iterator[Dict[str, Datapoint]]: """ Parameters: @@ -845,6 +854,7 @@ def subscribe_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Iterato ): yield {update.entry.path: update.entry.actuator_target for update in updates} + @check_connected def subscribe_metadata( self, paths: Iterable[str], field: MetadataField = MetadataField.ALL, @@ -907,7 +917,7 @@ def set(self, updates: Collection[EntryUpdate], **rpc_kwargs) -> None: raise VSSClientError.from_grpc_error(exc) from exc self._process_set_response(resp) - # needs to be handled differently + @check_connected def subscribe(self, entries: Iterable[SubscribeEntry], **rpc_kwargs) -> Iterator[List[EntryUpdate]]: """ Parameters: @@ -915,19 +925,16 @@ def subscribe(self, entries: Iterable[SubscribeEntry], **rpc_kwargs) -> Iterator grpc.*MultiCallable kwargs e.g. timeout, metadata, credentials. """ - if self.connected: - rpc_kwargs["metadata"] = self.generate_metadata_header( - rpc_kwargs.get("metadata")) - req = self._prepare_subscribe_request(entries) - resp_stream = self.client_stub.Subscribe(req, **rpc_kwargs) - try: - for resp in resp_stream: - logger.debug("%s: %s", type(resp).__name__, resp) - yield [EntryUpdate.from_message(update) for update in resp.updates] - except RpcError as exc: - raise VSSClientError.from_grpc_error(exc) from exc - else: - logger.info("Disconnected from server! Try connect.") + rpc_kwargs["metadata"] = self.generate_metadata_header( + rpc_kwargs.get("metadata")) + req = self._prepare_subscribe_request(entries) + resp_stream = self.client_stub.Subscribe(req, **rpc_kwargs) + try: + for resp in resp_stream: + logger.debug("%s: %s", type(resp).__name__, resp) + yield [EntryUpdate.from_message(update) for update in resp.updates] + except RpcError as exc: + raise VSSClientError.from_grpc_error(exc) from exc @check_connected def authorize(self, token: str, **rpc_kwargs) -> str: @@ -971,6 +978,7 @@ def get_server_info(self, **rpc_kwargs) -> Optional[ServerInfo]: raise VSSClientError.from_grpc_error(exc) from exc return None + @check_connected def get_value_types(self, paths: Collection[str], **rpc_kwargs) -> Dict[str, DataType]: """ Parameters: diff --git a/kuksa-client/kuksa_client/grpc/aio.py b/kuksa-client/kuksa_client/grpc/aio.py index f22dba8a2..6661aab73 100644 --- a/kuksa-client/kuksa_client/grpc/aio.py +++ b/kuksa-client/kuksa_client/grpc/aio.py @@ -93,14 +93,34 @@ async def disconnect(self): self.connected = False def check_connected_async(func): + """ + Decorator to verify that there is a connection before calling underlying method + For generator methods use check_connected_async_iter + """ async def wrapper(self, *args, **kwargs): if self.connected: return await func(self, *args, **kwargs) else: - logger.info( - "Disconnected from server! Try cli command connect.") + # This shall normally not happen if you use the client as context manager + # as then a connect will happen automatically when you enter the context + raise Exception("Server not connected! Call connect() before using this command!") + return wrapper + + def check_connected_async_iter(func): + """ + Decorator for generator methods to verify that there is a connection before calling underlying method + """ + async def wrapper(self, *args, **kwargs): + if self.connected: + async for v in func(self, *args, **kwargs): + yield v + else: + # This shall normally not happen if you use the client as context manager + # as then a connect will happen automatically when you enter the context + raise Exception("Server not connected! Call connect() before using this command!") return wrapper + @check_connected_async async def get_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Datapoint]: """ Parameters: @@ -120,6 +140,7 @@ async def get_current_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[s ) return {entry.path: entry.value for entry in entries} + @check_connected_async async def get_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[str, Datapoint]: """ Parameters: @@ -137,6 +158,7 @@ async def get_target_values(self, paths: Iterable[str], **rpc_kwargs) -> Dict[st ) for path in paths)) return {entry.path: entry.actuator_target for entry in entries} + @check_connected_async async def get_metadata( self, paths: Iterable[str], field: MetadataField = MetadataField.ALL, **rpc_kwargs, ) -> Dict[str, Metadata]: @@ -158,6 +180,7 @@ async def get_metadata( ) return {entry.path: entry.metadata for entry in entries} + @check_connected_async async def set_current_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> None: """ Parameters: @@ -175,6 +198,7 @@ async def set_current_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) **rpc_kwargs, ) + @check_connected_async async def set_target_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) -> None: """ Parameters: @@ -187,6 +211,7 @@ async def set_target_values(self, updates: Dict[str, Datapoint], **rpc_kwargs) - DataEntry(path, actuator_target=dp), (Field.ACTUATOR_TARGET,), ) for path, dp in updates.items()], **rpc_kwargs) + @check_connected_async async def set_metadata( self, updates: Dict[str, Metadata], field: MetadataField = MetadataField.ALL, **rpc_kwargs, ) -> None: @@ -203,6 +228,7 @@ async def set_metadata( DataEntry(path, metadata=md), (Field(field.value),), ) for path, md in updates.items()], **rpc_kwargs) + @check_connected_async_iter async def subscribe_current_values(self, paths: Iterable[str], **rpc_kwargs) -> AsyncIterator[Dict[str, Datapoint]]: """ Parameters: @@ -222,6 +248,7 @@ async def subscribe_current_values(self, paths: Iterable[str], **rpc_kwargs) -> ): yield {update.entry.path: update.entry.value for update in updates} + @check_connected_async_iter async def subscribe_target_values(self, paths: Iterable[str], **rpc_kwargs) -> AsyncIterator[Dict[str, Datapoint]]: """ Parameters: @@ -241,6 +268,7 @@ async def subscribe_target_values(self, paths: Iterable[str], **rpc_kwargs) -> A ): yield {update.entry.path: update.entry.actuator_target for update in updates} + @check_connected_async_iter async def subscribe_metadata( self, paths: Iterable[str], field: MetadataField = MetadataField.ALL, @@ -302,6 +330,7 @@ async def set(self, updates: Collection[EntryUpdate], **rpc_kwargs) -> None: raise VSSClientError.from_grpc_error(exc) from exc self._process_set_response(resp) + @check_connected_async_iter async def subscribe(self, entries: Iterable[SubscribeEntry], **rpc_kwargs, @@ -311,19 +340,16 @@ async def subscribe(self, rpc_kwargs grpc.*MultiCallable kwargs e.g. timeout, metadata, credentials. """ - if self.connected: - rpc_kwargs["metadata"] = self.generate_metadata_header( - rpc_kwargs.get("metadata")) - req = self._prepare_subscribe_request(entries) - resp_stream = self.client_stub.Subscribe(req, **rpc_kwargs) - try: - async for resp in resp_stream: - logger.debug("%s: %s", type(resp).__name__, resp) - yield [EntryUpdate.from_message(update) for update in resp.updates] - except AioRpcError as exc: - raise VSSClientError.from_grpc_error(exc) from exc - else: - logger.info("Disconnected from server! Try connect.") + rpc_kwargs["metadata"] = self.generate_metadata_header( + rpc_kwargs.get("metadata")) + req = self._prepare_subscribe_request(entries) + resp_stream = self.client_stub.Subscribe(req, **rpc_kwargs) + try: + async for resp in resp_stream: + logger.debug("%s: %s", type(resp).__name__, resp) + yield [EntryUpdate.from_message(update) for update in resp.updates] + except AioRpcError as exc: + raise VSSClientError.from_grpc_error(exc) from exc @check_connected_async async def authorize(self, token: str, **rpc_kwargs) -> str: @@ -367,6 +393,7 @@ async def get_server_info(self, **rpc_kwargs) -> Optional[ServerInfo]: raise VSSClientError.from_grpc_error(exc) from exc return None + @check_connected_async async def get_value_types(self, paths: Collection[str], **rpc_kwargs) -> Dict[str, DataType]: """ Parameters: diff --git a/kuksa-client/tests/test_grpc.py b/kuksa-client/tests/test_grpc.py index 650d2943d..91d06f517 100644 --- a/kuksa-client/tests/test_grpc.py +++ b/kuksa-client/tests/test_grpc.py @@ -354,6 +354,8 @@ async def test_secure_connection(self, unused_tcp_port, resources_path, val_serv async def test_get_current_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection test + mocker.patch.object(client, 'get', return_value=[ DataEntry('Vehicle.Speed', value=Datapoint( 42.0, datetime.datetime( @@ -382,6 +384,7 @@ async def test_get_current_values(self, mocker, unused_tcp_port): async def test_get_target_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check mocker.patch.object(client, 'get', return_value=[ DataEntry('Vehicle.ADAS.ABS.IsActive', actuator_target=Datapoint( True, datetime.datetime( @@ -405,6 +408,7 @@ async def test_get_target_values(self, mocker, unused_tcp_port): async def test_get_metadata(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check mocker.patch.object(client, 'get', return_value=[ DataEntry('Vehicle.Speed', metadata=Metadata( entry_type=EntryType.SENSOR)), @@ -433,6 +437,7 @@ async def test_get_metadata(self, mocker, unused_tcp_port): async def test_set_current_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check mocker.patch.object(client, 'set') await client.set_current_values({ 'Vehicle.Speed': Datapoint(42.0, @@ -454,6 +459,7 @@ async def test_set_current_values(self, mocker, unused_tcp_port): async def test_set_target_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check mocker.patch.object(client, 'set') await client.set_target_values({ 'Vehicle.ADAS.ABS.IsActive': Datapoint(True, datetime.datetime(2022, 11, 7, tzinfo=datetime.timezone.utc)), @@ -470,6 +476,7 @@ async def test_set_target_values(self, mocker, unused_tcp_port): async def test_set_metadata(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check mocker.patch.object(client, 'set') await client.set_metadata({ 'Vehicle.Speed': Metadata(entry_type=EntryType.SENSOR), @@ -490,6 +497,7 @@ async def test_set_metadata(self, mocker, unused_tcp_port): async def test_subscribe_current_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check async def subscribe_response_stream(**kwargs): yield [ @@ -529,6 +537,7 @@ async def subscribe_response_stream(**kwargs): async def test_subscribe_target_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check async def subscribe_response_stream(**kwargs): yield [ @@ -561,6 +570,7 @@ async def subscribe_response_stream(**kwargs): async def test_subscribe_metadata(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) + client.connected = True # To bypass connection check async def subscribe_response_stream(**kwargs): yield [