From afbe8db955a14728acfbd6fbf805937c7a0958c6 Mon Sep 17 00:00:00 2001 From: rafaeling Date: Fri, 8 Sep 2023 14:15:18 +0200 Subject: [PATCH] Use entries iterator and handles permissions correctly --- kuksa-client/kuksa_client/cli_backend/grpc.py | 10 +- kuksa-client/kuksa_client/grpc/__init__.py | 46 ++-- kuksa-client/kuksa_client/grpc/aio.py | 4 +- kuksa-client/tests/test_grpc.py | 42 +++- kuksa_databroker/databroker/src/broker.rs | 23 -- .../databroker/src/grpc/kuksa_val_v1/val.rs | 208 ++++++++++-------- 6 files changed, 179 insertions(+), 154 deletions(-) diff --git a/kuksa-client/kuksa_client/cli_backend/grpc.py b/kuksa-client/kuksa_client/cli_backend/grpc.py index 784f9dbec..c74a5ae46 100644 --- a/kuksa-client/kuksa_client/cli_backend/grpc.py +++ b/kuksa-client/kuksa_client/cli_backend/grpc.py @@ -28,7 +28,6 @@ from typing import Optional import uuid import os -import re from kuksa_client import cli_backend import kuksa_client.grpc @@ -36,14 +35,17 @@ from kuksa_client.grpc import EntryUpdate from kuksa.val.v1 import types_pb2 + def callback_wrapper(callback: Callable[[str], None]) -> Callable[[Iterable[EntryUpdate]], None]: def wrapper(updates: Iterable[EntryUpdate]) -> None: callback(json.dumps([update.to_dict() for update in updates])) return wrapper + class DatabrokerEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, (types_pb2.StringArray, types_pb2.BoolArray, types_pb2.Uint32Array, types_pb2.Uint64Array, types_pb2.FloatArray, types_pb2.Int32Array, types_pb2.Int64Array, types_pb2.DoubleArray)): + if isinstance(obj, (types_pb2.StringArray, types_pb2.BoolArray, types_pb2.Uint32Array, types_pb2.Uint64Array, + types_pb2.FloatArray, types_pb2.Int32Array, types_pb2.Int64Array, types_pb2.DoubleArray)): string_values = [] for value in obj.values: value = str(value) @@ -145,7 +147,7 @@ def setValues(self, updates: Dict[str, Any], attribute="value", timeout=5): return json.dumps({"error": "Invalid Attribute"}) # Function for authorization - def authorize(self, token_or_tokenfile:Optional[str] =None, timeout=5): + def authorize(self, token_or_tokenfile: Optional[str] = None, timeout=5): if token_or_tokenfile is None: token_or_tokenfile = self.token_or_tokenfile if os.path.isfile(token_or_tokenfile): @@ -251,7 +253,7 @@ async def _grpcHandler(self, vss_client: kuksa_client.grpc.aio.VSSClient): responseQueue.put((resp, None)) except kuksa_client.grpc.VSSClientError as exc: responseQueue.put((None, exc.to_dict())) - except ValueError as exc: + except ValueError: responseQueue.put( (None, {"error": "ValueError in casting the value."})) diff --git a/kuksa-client/kuksa_client/grpc/__init__.py b/kuksa-client/kuksa_client/grpc/__init__.py index 45c536ac9..fa1d24ede 100644 --- a/kuksa-client/kuksa_client/grpc/__init__.py +++ b/kuksa-client/kuksa_client/grpc/__init__.py @@ -315,7 +315,6 @@ def from_message(cls, message: types_pb2.Datapoint): ) if message.HasField('timestamp') else None, ) - def cast_array_values(cast, array): """ Parses array input and cast individual values to wanted type. @@ -339,7 +338,7 @@ def cast_array_values(cast, array): # My Way # ... without quotes if item.strip() == '': - #skip + # skip pass else: yield cast(item) @@ -365,7 +364,7 @@ def cast_str(value) -> str: new_val = new_val.replace('\\\"', '\"') new_val = new_val.replace("\\\'", "\'") return new_val - + def to_message(self, value_type: DataType) -> types_pb2.Datapoint: message = types_pb2.Datapoint() @@ -374,7 +373,6 @@ def set_array_attr(obj, attr, values): array.Clear() array.values.extend(values) - field, set_field, cast_field = { DataType.INT8: ('int32', setattr, int), DataType.INT16: ('int32', setattr, int), @@ -388,29 +386,29 @@ def set_array_attr(obj, attr, values): DataType.DOUBLE: ('double', setattr, float), DataType.BOOLEAN: ('bool', setattr, Datapoint.cast_bool), DataType.STRING: ('string', setattr, Datapoint.cast_str), - DataType.INT8_ARRAY: ('int32_array', set_array_attr, + DataType.INT8_ARRAY: ('int32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.INT16_ARRAY: ('int32_array', set_array_attr, + DataType.INT16_ARRAY: ('int32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.INT32_ARRAY: ('int32_array', set_array_attr, + DataType.INT32_ARRAY: ('int32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.UINT8_ARRAY: ('uint32_array', set_array_attr, + DataType.UINT8_ARRAY: ('uint32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.UINT16_ARRAY: ('uint32_array', set_array_attr, + DataType.UINT16_ARRAY: ('uint32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.UINT32_ARRAY: ('uint32_array', set_array_attr, + DataType.UINT32_ARRAY: ('uint32_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.UINT64_ARRAY: ('uint64_array', set_array_attr, + DataType.UINT64_ARRAY: ('uint64_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.INT64_ARRAY: ('int64_array', set_array_attr, + DataType.INT64_ARRAY: ('int64_array', set_array_attr, lambda array: Datapoint.cast_array_values(int, array)), - DataType.FLOAT_ARRAY: ('float_array', set_array_attr, + DataType.FLOAT_ARRAY: ('float_array', set_array_attr, lambda array: Datapoint.cast_array_values(float, array)), - DataType.DOUBLE_ARRAY: ('double_array', set_array_attr, + DataType.DOUBLE_ARRAY: ('double_array', set_array_attr, lambda array: Datapoint.cast_array_values(float, array)), - DataType.BOOLEAN_ARRAY: ('bool_array', set_array_attr, + DataType.BOOLEAN_ARRAY: ('bool_array', set_array_attr, lambda array: Datapoint.cast_array_values(Datapoint.cast_bool, array)), - DataType.STRING_ARRAY: ('string_array', set_array_attr, + DataType.STRING_ARRAY: ('string_array', set_array_attr, lambda array: Datapoint.cast_array_values(Datapoint.cast_str, array)), }.get(value_type, (None, None, None)) if self.value is not None: @@ -523,6 +521,7 @@ class ServerInfo: def from_message(cls, message: val_pb2.GetServerInfoResponse): return cls(name=message.name, version=message.version) + class BaseVSSClient: def __init__( self, @@ -536,7 +535,6 @@ def __init__( connected: bool = False, tls_server_name: Optional[str] = None ): - self.authorization_header = self.get_authorization_header(token) self.target_host = f'{host}:{port}' @@ -559,11 +557,10 @@ def _load_creds(self) -> Optional[grpc.ChannelCredentials]: logger.info("Using client private key and certificates, mutual TLS supported if supported by server") return grpc.ssl_channel_credentials(root_certificates, private_key, certificate_chain) else: - logger.info(f"No client certificates provided, mutual TLS not supported!") + logger.info("No client certificates provided, mutual TLS not supported!") return grpc.ssl_channel_credentials(root_certificates) - logger.info(f"No Root CA present, it will not be posible to use a secure connection!") + logger.info("No Root CA present, it will not be posible to use a secure connection!") return None - def _prepare_get_request(self, entries: Iterable[EntryRequest]) -> val_pb2.GetRequest: req = val_pb2.GetRequest(entries=[]) @@ -649,7 +646,7 @@ def get_authorization_header(self, token: str): return "Bearer " + token def generate_metadata_header(self, metadata: list, header=None) -> list: - if header == None: + if header is None: header = self.authorization_header if metadata: metadata = dict(metadata) @@ -686,7 +683,7 @@ def connect(self, target_host=None): creds = self._load_creds() if target_host is None: target_host = self.target_host - + if creds is not None: logger.info("Establishing secure channel") if self.tls_server_name: @@ -694,12 +691,12 @@ def connect(self, target_host=None): options = [('grpc.ssl_target_name_override', self.tls_server_name)] channel = grpc.secure_channel(target_host, creds, options) else: - logger.debug(f"Not providing explicit TLS server name") + logger.debug("Not providing explicit TLS server name") channel = grpc.secure_channel(target_host, creds) else: logger.info("Establishing insecure channel") channel = grpc.insecure_channel(target_host) - + self.channel = self.exit_stack.enter_context(channel) self.client_stub = val_pb2_grpc.VALStub(self.channel) self.connected = True @@ -973,7 +970,6 @@ def get_server_info(self, **rpc_kwargs) -> Optional[ServerInfo]: else: raise VSSClientError.from_grpc_error(exc) from exc return None - def get_value_types(self, paths: Collection[str], **rpc_kwargs) -> Dict[str, DataType]: """ diff --git a/kuksa-client/kuksa_client/grpc/aio.py b/kuksa-client/kuksa_client/grpc/aio.py index d8a90c89c..f22dba8a2 100644 --- a/kuksa-client/kuksa_client/grpc/aio.py +++ b/kuksa-client/kuksa_client/grpc/aio.py @@ -75,7 +75,7 @@ async def connect(self, target_host=None): options = [('grpc.ssl_target_name_override', self.tls_server_name)] channel = grpc.aio.secure_channel(target_host, creds, options) else: - logger.debug(f"Not providing explicit TLS server name") + logger.debug("Not providing explicit TLS server name") channel = grpc.aio.secure_channel(target_host, creds) else: logger.info("Establishing insecure channel") @@ -363,7 +363,7 @@ async def get_server_info(self, **rpc_kwargs) -> Optional[ServerInfo]: except AioRpcError as exc: if exc.code() == grpc.StatusCode.UNAUTHENTICATED: logger.info("Unauthenticated channel started") - else: + else: raise VSSClientError.from_grpc_error(exc) from exc return None diff --git a/kuksa-client/tests/test_grpc.py b/kuksa-client/tests/test_grpc.py index 823e4fea2..650d2943d 100644 --- a/kuksa-client/tests/test_grpc.py +++ b/kuksa-client/tests/test_grpc.py @@ -435,7 +435,9 @@ async def test_set_current_values(self, mocker, unused_tcp_port): client = VSSClient('127.0.0.1', unused_tcp_port) mocker.patch.object(client, 'set') await client.set_current_values({ - 'Vehicle.Speed': Datapoint(42.0, datetime.datetime(2022, 11, 7, 16, 18, 35, 247307, tzinfo=datetime.timezone.utc)), + 'Vehicle.Speed': Datapoint(42.0, + datetime.datetime(2022, 11, 7, 16, 18, 35, 247307, + tzinfo=datetime.timezone.utc)), 'Vehicle.ADAS.ABS.IsActive': Datapoint(True), 'Vehicle.Chassis.Height': Datapoint(666), }) @@ -518,7 +520,9 @@ async def subscribe_response_stream(**kwargs): View.CURRENT_VALUE, (Field.VALUE,)), ] assert received_updates == { - 'Vehicle.Speed': Datapoint(42.0, datetime.datetime(2022, 11, 7, 16, 18, 35, 247307, tzinfo=datetime.timezone.utc)), + 'Vehicle.Speed': Datapoint(42.0, + datetime.datetime(2022, 11, 7, 16, 18, 35, 247307, + tzinfo=datetime.timezone.utc)), 'Vehicle.ADAS.ABS.IsActive': Datapoint(True), 'Vehicle.Chassis.Height': Datapoint(666), } @@ -652,7 +656,8 @@ async def test_get_some_entries(self, unused_tcp_port, val_servicer): ), ]) async with VSSClient('127.0.0.1', unused_tcp_port, ensure_startup_connection=False) as client: - entries = await client.get(entries=(entry for entry in ( # generator is intentional as get accepts Iterable + entries = await client.get(entries=(entry for entry in ( + # generator is intentional as get accepts Iterable EntryRequest('Vehicle.Speed', View.CURRENT_VALUE, (Field.VALUE,)), EntryRequest('Vehicle.ADAS.ABS.IsActive', @@ -795,7 +800,6 @@ async def test_get_unset_entries(self, unused_tcp_port, val_servicer): EntryRequest('Vehicle.ADAS.ABS.IsActive', View.TARGET_VALUE, (Field.ACTUATOR_TARGET,)), )) - assert entries == [DataEntry('Vehicle.Speed'), DataEntry( 'Vehicle.ADAS.ABS.IsActive')] @@ -991,8 +995,32 @@ async def test_authorize_successful(self, unused_tcp_port, val_servicer): name='test_server', version='1.2.3') async with VSSClient('127.0.0.1', unused_tcp_port, ensure_startup_connection=False) as client: # token from kuksa.val directory under jwt/provide-vehicle-speed.token - success = await client.authorize(token='eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJsb2NhbCBkZXYiLCJpc3MiOiJjcmVhdGVUb2tlbi5weSIsImF1ZCI6WyJrdWtzYS52YWwiXSwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE3NjcyMjU1OTksInNjb3BlIjoicmVhZDpWZWhpY2xlLldpZHRoIHByb3ZpZGU6VmVoaWNsZS5TcGVlZCJ9.w2c8xrYwBVgMav3f0Se6E8H8E36Nd03rJiSS2A8s-CL3GtlwB7wVanjXHhppNsCdWym3tK4JwgslQdMQF-UL4hd7vzdtt-Mx6VjH_jO9mDxz4Z0Uzw7aJtbtQSpi2h6kwceTVTllkbLRF7WRHWIpwzXFF9yZolX6lH-BE9xf1AB62d6icd9SKxFnVvYs3MVK5D1xNmDNOmm-Fr0d2K604MmIIXGW5kPZJYIvBKO4NYRLklhJe47It_lGo3gnh1ppmzTOIo1kB4sDe55hplUCbTCJVricpyQSgTYsf7aFRPK51XMRwwwJ8kShWeaTggMLKpv1W-9dhVWDk4isC8BxsOjaVloArausMmjLmTz6KwAsfARgfXtaCrMsESUBNXi5KIdAyHVXZpmERvc9yeYPcaWlknVFrFsHbV6bw4nwqBX-0Ubuga0NGNQDFKmyTKQrbuZmQ3L9iipxY8_BOSCkdiYtWbE3lpplxpS_PaZl10KAaMmUfbcF9aYZunDEzEtoJgJe2EeGu3XDBtbyXVUKruImdSEdjaImfUGQIWl5bMbVH4N4zK5jE45wT5FJiRUcA5pMN5wNmDYJJzgbxWNpYW40KZYPFc_7XUH8EZ2Cs69wDHam3ArkOs1qMgMIoEPWVzHakjlVJfrPR9zQKxfirBtNNENIoHsBjJ_P4FEJCN4') - assert client.authorization_header == 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJsb2NhbCBkZXYiLCJpc3MiOiJjcmVhdGVUb2tlbi5weSIsImF1ZCI6WyJrdWtzYS52YWwiXSwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE3NjcyMjU1OTksInNjb3BlIjoicmVhZDpWZWhpY2xlLldpZHRoIHByb3ZpZGU6VmVoaWNsZS5TcGVlZCJ9.w2c8xrYwBVgMav3f0Se6E8H8E36Nd03rJiSS2A8s-CL3GtlwB7wVanjXHhppNsCdWym3tK4JwgslQdMQF-UL4hd7vzdtt-Mx6VjH_jO9mDxz4Z0Uzw7aJtbtQSpi2h6kwceTVTllkbLRF7WRHWIpwzXFF9yZolX6lH-BE9xf1AB62d6icd9SKxFnVvYs3MVK5D1xNmDNOmm-Fr0d2K604MmIIXGW5kPZJYIvBKO4NYRLklhJe47It_lGo3gnh1ppmzTOIo1kB4sDe55hplUCbTCJVricpyQSgTYsf7aFRPK51XMRwwwJ8kShWeaTggMLKpv1W-9dhVWDk4isC8BxsOjaVloArausMmjLmTz6KwAsfARgfXtaCrMsESUBNXi5KIdAyHVXZpmERvc9yeYPcaWlknVFrFsHbV6bw4nwqBX-0Ubuga0NGNQDFKmyTKQrbuZmQ3L9iipxY8_BOSCkdiYtWbE3lpplxpS_PaZl10KAaMmUfbcF9aYZunDEzEtoJgJe2EeGu3XDBtbyXVUKruImdSEdjaImfUGQIWl5bMbVH4N4zK5jE45wT5FJiRUcA5pMN5wNmDYJJzgbxWNpYW40KZYPFc_7XUH8EZ2Cs69wDHam3ArkOs1qMgMIoEPWVzHakjlVJfrPR9zQKxfirBtNNENIoHsBjJ_P4FEJCN4' + token = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJsb2NhbCBkZXYiLCJpc3MiOiJjcmVhdGVUb2' + 'tlbi5weSIsImF1ZCI6WyJrdWtzYS52YWwiXSwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE3NjcyMjU1OTksIn' + 'Njb3BlIjoicmVhZDpWZWhpY2xlLldpZHRoIHByb3ZpZGU6VmVoaWNsZS5TcGVlZCJ9.w2c8xrYwBVgMav3f0S' + 'e6E8H8E36Nd03rJiSS2A8s-CL3GtlwB7wVanjXHhppNsCdWym3tK4JwgslQdMQF-UL4hd7vzdtt-Mx6VjH_jO9' + 'mDxz4Z0Uzw7aJtbtQSpi2h6kwceTVTllkbLRF7WRHWIpwzXFF9yZolX6lH-BE9xf1AB62d6icd9SKxFnVvYs3M' + 'VK5D1xNmDNOmm-Fr0d2K604MmIIXGW5kPZJYIvBKO4NYRLklhJe47It_lGo3gnh1ppmzTOIo1kB4sDe55hplUCb' + 'TCJVricpyQSgTYsf7aFRPK51XMRwwwJ8kShWeaTggMLKpv1W-9dhVWDk4isC8BxsOjaVloArausMmjLmTz6KwAsf' + 'ARgfXtaCrMsESUBNXi5KIdAyHVXZpmERvc9yeYPcaWlknVFrFsHbV6bw4nwqBX-0Ubuga0NGNQDFKmyTKQrbuZmQ' + '3L9iipxY8_BOSCkdiYtWbE3lpplxpS_PaZl10KAaMmUfbcF9aYZunDEzEtoJgJe2EeGu3XDBtbyXVUKruImdSEdja' + 'ImfUGQIWl5bMbVH4N4zK5jE45wT5FJiRUcA5pMN5wNmDYJJzgbxWNpYW40KZYPFc_7XUH8EZ2Cs69wDHam3ArkOs1' + 'qMgMIoEPWVzHakjlVJfrPR9zQKxfirBtNNENIoHsBjJ_P4FEJCN4' + ) + bearer = ('Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJsb2NhbCBkZXYiLCJpc3MiOiJjcmVhdG' + 'VUb2tlbi5weSIsImF1ZCI6WyJrdWtzYS52YWwiXSwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE3NjcyMjU1OTks' + 'InNjb3BlIjoicmVhZDpWZWhpY2xlLldpZHRoIHByb3ZpZGU6VmVoaWNsZS5TcGVlZCJ9.w2c8xrYwBVgMav3f0' + 'Se6E8H8E36Nd03rJiSS2A8s-CL3GtlwB7wVanjXHhppNsCdWym3tK4JwgslQdMQF-UL4hd7vzdtt-Mx6VjH_jO' + '9mDxz4Z0Uzw7aJtbtQSpi2h6kwceTVTllkbLRF7WRHWIpwzXFF9yZolX6lH-BE9xf1AB62d6icd9SKxFnVvYs3M' + 'VK5D1xNmDNOmm-Fr0d2K604MmIIXGW5kPZJYIvBKO4NYRLklhJe47It_lGo3gnh1ppmzTOIo1kB4sDe55hplUCbT' + 'CJVricpyQSgTYsf7aFRPK51XMRwwwJ8kShWeaTggMLKpv1W-9dhVWDk4isC8BxsOjaVloArausMmjLmTz6KwAsf' + 'ARgfXtaCrMsESUBNXi5KIdAyHVXZpmERvc9yeYPcaWlknVFrFsHbV6bw4nwqBX-0Ubuga0NGNQDFKmyTKQrbuZmQ' + '3L9iipxY8_BOSCkdiYtWbE3lpplxpS_PaZl10KAaMmUfbcF9aYZunDEzEtoJgJe2EeGu3XDBtbyXVUKruImdSEdj' + 'aImfUGQIWl5bMbVH4N4zK5jE45wT5FJiRUcA5pMN5wNmDYJJzgbxWNpYW40KZYPFc_7XUH8EZ2Cs69wDHam3ArkO' + 's1qMgMIoEPWVzHakjlVJfrPR9zQKxfirBtNNENIoHsBjJ_P4FEJCN4' + ) + success = await client.authorize(token) + assert client.authorization_header == bearer assert success == "Authenticated" @pytest.mark.usefixtures('val_server') @@ -1002,7 +1030,7 @@ async def test_authorize_unsuccessful(self, unused_tcp_port, val_servicer): async with VSSClient('127.0.0.1', unused_tcp_port, ensure_startup_connection=False) as client: with pytest.raises(VSSClientError): await client.authorize(token='') - assert client.authorization_header == None + assert client.authorization_header is None @pytest.mark.usefixtures('val_server') async def test_subscribe_some_entries(self, mocker, unused_tcp_port, val_servicer): diff --git a/kuksa_databroker/databroker/src/broker.rs b/kuksa_databroker/databroker/src/broker.rs index 5e9a18e36..330e516a3 100644 --- a/kuksa_databroker/databroker/src/broker.rs +++ b/kuksa_databroker/databroker/src/broker.rs @@ -930,20 +930,6 @@ impl<'a, 'b> DatabaseReadAccess<'a, 'b> { } } - pub fn get_entries_by_regex(&self, regex: regex::Regex) -> Result, ReadError> { - let mut entries: Vec = Vec::new(); - for (_, value) in self.db.entries.iter() { - if regex.is_match(&value.metadata.path) { - entries.push(value.clone()); - } - } - if entries.is_empty() { - return Err(ReadError::NotFound); - } - - Ok(entries) - } - pub fn get_metadata_by_id(&self, id: i32) -> Option<&Metadata> { self.db.entries.get(&id).map(|entry| &entry.metadata) } @@ -1231,15 +1217,6 @@ impl<'a, 'b> AuthorizedAccess<'a, 'b> { .cloned() } - pub async fn get_entries_by_regex(&self, regex: regex::Regex) -> Result, ReadError> { - self.broker - .database - .read() - .await - .authorized_read_access(self.permissions) - .get_entries_by_regex(regex) - } - pub async fn get_entry_by_id(&self, id: i32) -> Result { self.broker .database diff --git a/kuksa_databroker/databroker/src/grpc/kuksa_val_v1/val.rs b/kuksa_databroker/databroker/src/grpc/kuksa_val_v1/val.rs index f0aaca44a..35fcf0dd0 100644 --- a/kuksa_databroker/databroker/src/grpc/kuksa_val_v1/val.rs +++ b/kuksa_databroker/databroker/src/grpc/kuksa_val_v1/val.rs @@ -23,6 +23,7 @@ use tokio_stream::StreamExt; use tracing::debug; use crate::broker; +use crate::broker::EntryReadAccess; use crate::broker::ReadError; use crate::broker::SubscriptionError; use crate::glob; @@ -53,7 +54,23 @@ impl proto::val_server::Val for broker::DataBroker { } else { let mut entries = Vec::new(); let mut errors = Vec::new(); - + /* + * valid_requests: A collection of valid requests, each represented as a tuple with five fields: + * - Regex: The regular expression created from the string path request. + * - Fields: A HashSet of proto::Field objects extracted from the request. + * - RequestPath: The original request path, used for error reporting when no entries match. + * - IsMatch: A boolean flag indicating whether the current request matches any entry. + * - Error: An optional ReadError representing a permission error that may occur when querying a valid path entry. + */ + let mut valid_requests: Vec<( + regex::Regex, + HashSet, + String, + bool, + Option, + )> = Vec::new(); + + // Fill valid_requests structure. for request in requested { if !glob::is_valid_pattern(&request.path) { errors.push(proto::DataEntryError { @@ -66,7 +83,6 @@ impl proto::val_server::Val for broker::DataBroker { }); continue; } - let view = proto::View::from_i32(request.view).ok_or_else(|| { tonic::Status::invalid_argument(format!("Invalid View (id: {}", request.view)) })?; @@ -76,64 +92,92 @@ impl proto::val_server::Val for broker::DataBroker { let view_fields = combine_view_and_fields(view, fields); debug!("Getting fields: {:?}", view_fields); - match broker.get_entry_by_path(&request.path).await { - Ok(entry) => { - let proto_entry = proto_entry_from_entry_and_fields(entry, view_fields); - debug!("Getting datapoint: {:?}", proto_entry); - entries.push(proto_entry); + let regex_exp = glob::to_regex(&request.path); + match regex_exp { + Ok(value) => { + valid_requests.push((value, view_fields, request.path, false, None)); + } + Err(_) => { + errors.push(proto::DataEntryError { + path: request.path, + error: Some(proto::Error { + code: 400, + reason: "bad regex".to_owned(), + message: "Regex can't be created for provided path".to_owned(), + }), + }); } - Err(ReadError::NotFound) => { - let regex = glob::to_regex(&request.path); - match regex { - Ok(value_regex) => { - match broker.get_entries_by_regex(value_regex).await { - Ok(entries_result) => { - for data_entry in entries_result { - let entry_type = data_entry.metadata.entry_type.clone(); - - let proto_entry = proto_entry_from_entry_and_fields( - data_entry, - view_fields.clone(), - ); - debug!("Getting datapoint: {:?}", proto_entry); - - if view == proto::View::TargetValue { - if entry_type == broker::EntryType::Actuator { - entries.push(proto_entry); - } - } else { - entries.push(proto_entry); - } + } + } + if !valid_requests.is_empty() { + broker + .for_each_entry(|entry| { + let mut result_fields: HashSet = HashSet::new(); + for (regex, view_fields, _, is_match, op_error) in &mut valid_requests { + let path = &entry.metadata().path; + if regex.is_match(path) { + // Update the `is_match` to indicate a valid and used request path. + *is_match = true; + if view_fields.contains(&proto::Field::Metadata) { + result_fields.extend(view_fields.clone()); + } else { + match entry.datapoint() { + Ok(_value) => { + // If the entry's path matches the regex and there is access permission, + // add the result fields to the current entry. + result_fields.extend(view_fields.clone()); } - if entries.is_empty() { - errors.push(proto::DataEntryError { - path: request.path, - error: Some(proto::Error { - code: 404, - reason: "not_found".to_owned(), - message: "No entries found for the provided call parameters" - .to_owned(), - }), - }); + Err(error) => { + //Propagate the error + *op_error = Some(error); } } - Err(read_error) => { - let data_entry_error = create_data_entry_error( - request.path.clone(), - read_error, - ); - errors.push(data_entry_error); - } } } - Err(_) => todo!(), } - } - Err(read_error) => { - let data_entry_error = - create_data_entry_error(request.path.clone(), read_error); - errors.push(data_entry_error); - } + + // If there are result fields, add them to the entries list. + if !result_fields.is_empty() { + let proto_entry = + proto_entry_from_entry_and_fields(entry, result_fields); + debug!("Getting datapoint: {:?}", proto_entry); + entries.push(proto_entry); + } + }) + .await; + } + + /* + * Handle Unmatched or Permission Errors + * + * After processing valid requests, this section iterates over the `valid_requests` vector + * to check if any requests didn't have matching entries or encountered permission errors. + * + * For each unmatched request, a "not_found" error message is added to the `errors` list. + * For requests with permission errors, a "forbidden" error message is added. + */ + for (_, _, path, is_match, error) in valid_requests { + if !is_match { + errors.push(proto::DataEntryError { + path: path.to_owned(), + error: Some(proto::Error { + code: 404, + reason: "not_found".to_owned(), + message: "No entries found for the provided path".to_owned(), + }), + }); + } else if let Some(_error) = error { + // clear the entries vector since we only want to return rerrors + // and not partial success + entries.clear(); + errors.push(proto::DataEntryError { + path: path.to_owned(), + error: Some(proto::Error { + code: 403, + reason: "forbidden".to_owned(), + message: "Permission denied for some entries".to_owned(), + }), + }); } } @@ -405,19 +449,25 @@ fn convert_to_proto_stream( } fn proto_entry_from_entry_and_fields( - entry: broker::Entry, + entry: EntryReadAccess, fields: HashSet, ) -> proto::DataEntry { - let path = entry.metadata.path; + let path = entry.metadata().path.to_string(); let value = if fields.contains(&proto::Field::Value) { - Option::::from(entry.datapoint) + match entry.datapoint() { + Ok(value) => Option::::from(value.clone()), + Err(_) => None, + } } else { None }; let actuator_target = if fields.contains(&proto::Field::ActuatorTarget) { - match entry.actuator_target { - Some(actuator_target) => Option::::from(actuator_target), - None => None, + match entry.actuator_target() { + Ok(value) => match value { + Some(value) => Option::::from(value.clone()), + None => None, + }, + Err(_) => None, } } else { None @@ -430,15 +480,15 @@ fn proto_entry_from_entry_and_fields( if all || fields.contains(&proto::Field::MetadataDataType) { metadata_is_set = true; - metadata.data_type = proto::DataType::from(entry.metadata.data_type) as i32; + metadata.data_type = proto::DataType::from(entry.metadata().data_type.clone()) as i32; } if all || fields.contains(&proto::Field::MetadataDescription) { metadata_is_set = true; - metadata.description = Some(entry.metadata.description); + metadata.description = Some(entry.metadata().description.clone()); } if all || fields.contains(&proto::Field::MetadataEntryType) { metadata_is_set = true; - metadata.entry_type = proto::EntryType::from(&entry.metadata.entry_type) as i32; + metadata.entry_type = proto::EntryType::from(&entry.metadata().entry_type) as i32; } if all || fields.contains(&proto::Field::MetadataComment) { metadata_is_set = true; @@ -462,7 +512,7 @@ fn proto_entry_from_entry_and_fields( if all || fields.contains(&proto::Field::MetadataActuator) { metadata_is_set = true; // TODO: Add to Metadata - metadata.entry_specific = match entry.metadata.entry_type { + metadata.entry_specific = match entry.metadata().entry_type { broker::EntryType::Actuator => { // Some(proto::metadata::EntrySpecific::Actuator( // proto::Actuator::default(), @@ -475,7 +525,7 @@ fn proto_entry_from_entry_and_fields( if all || fields.contains(&proto::Field::MetadataSensor) { metadata_is_set = true; // TODO: Add to Metadata - metadata.entry_specific = match entry.metadata.entry_type { + metadata.entry_specific = match entry.metadata().entry_type { broker::EntryType::Sensor => { // Some(proto::metadata::EntrySpecific::Sensor( // proto::Sensor::default(), @@ -488,7 +538,7 @@ fn proto_entry_from_entry_and_fields( if all || fields.contains(&proto::Field::MetadataAttribute) { metadata_is_set = true; // TODO: Add to Metadata - metadata.entry_specific = match entry.metadata.entry_type { + metadata.entry_specific = match entry.metadata().entry_type { broker::EntryType::Attribute => { // Some(proto::metadata::EntrySpecific::Attribute( // proto::Attribute::default(), @@ -513,34 +563,6 @@ fn proto_entry_from_entry_and_fields( } } -fn create_data_entry_error( - request_path: String, - read_error: broker::ReadError, -) -> proto::DataEntryError { - let error = match read_error { - broker::ReadError::NotFound => proto::Error { - code: 404, - reason: "not_found".to_owned(), - message: "No entries found for the provided path".to_owned(), - }, - broker::ReadError::PermissionExpired => proto::Error { - code: 401, - reason: "unauthorized".to_owned(), - message: "Authorization expired".to_owned(), - }, - broker::ReadError::PermissionDenied => proto::Error { - code: 403, - reason: "forbidden".to_owned(), - message: "Permission denied".to_owned(), - }, - }; - - proto::DataEntryError { - path: request_path, - error: Some(error), - } -} - fn combine_view_and_fields( view: proto::View, fields: impl IntoIterator,