diff --git a/kuksa_databroker/databroker/src/broker.rs b/kuksa_databroker/databroker/src/broker.rs index cb3071c3..d62a7ace 100644 --- a/kuksa_databroker/databroker/src/broker.rs +++ b/kuksa_databroker/databroker/src/broker.rs @@ -23,6 +23,7 @@ use tokio_stream::Stream; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::mem::take; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; use std::time::SystemTime; @@ -154,7 +155,7 @@ pub struct QuerySubscription { pub struct ChangeSubscription { entries: HashMap>, - sender: mpsc::Sender, + sender: Option>, permissions: Permissions, } @@ -606,7 +607,7 @@ impl Subscriptions { } pub async fn notify( - &self, + &mut self, changed: Option<&HashMap>>, db: &Database, ) -> Result>, NotificationError> { @@ -626,10 +627,14 @@ impl Subscriptions { } } - for sub in &self.change_subscriptions { + for sub in &mut self.change_subscriptions { match sub.notify(changed, db).await { Ok(_) => {} - Err(err) => error = Some(err), + Err(err) => { + error = Some(err); + let taken_sender = take(&mut sub.sender); + drop(taken_sender); + } } } @@ -660,11 +665,16 @@ impl Subscriptions { } }); self.change_subscriptions.retain(|sub| { - if sub.sender.is_closed() { + if let Some(sender) = &sub.sender { + if sender.is_closed() { + info!("Subscriber gone: removing subscription"); + false + } else { + true + } + } else { info!("Subscriber gone: removing subscription"); false - } else { - true } }); } @@ -692,7 +702,7 @@ impl ChangeSubscription { // notify let notifications = { let mut notifications = EntryUpdates::default(); - + let mut error = None; for (id, changed_fields) in changed { if let Some(fields) = self.entries.get(id) { if !fields.is_disjoint(changed_fields) { @@ -722,22 +732,32 @@ impl ChangeSubscription { fields: notify_fields, }); } + Err(ReadError::PermissionExpired) => { + debug!("notify: token expired, closing subscription channel"); + error = Some(NotificationError {}); + break; + } Err(_) => { - debug!("notify: could not find entry with id {}", id) + debug!("notify: could not find entry with id {}", id); } } } } } + if let Some(err) = error { + return Err(err); + } notifications }; if notifications.updates.is_empty() { Ok(()) - } else { - match self.sender.send(notifications).await { + } else if let Some(sender) = &self.sender { + match sender.send(notifications).await { Ok(()) => Ok(()), Err(_) => Err(NotificationError {}), } + } else { + Err(NotificationError {}) } } else { Ok(()) @@ -774,9 +794,13 @@ impl ChangeSubscription { } notifications }; - match self.sender.send(notifications).await { - Ok(()) => Ok(()), - Err(_) => Err(NotificationError {}), + if let Some(sender) = &self.sender { + match sender.send(notifications).await { + Ok(()) => Ok(()), + Err(_) => Err(NotificationError {}), + } + } else { + Err(NotificationError {}) } } } @@ -1408,7 +1432,7 @@ impl<'a, 'b> AuthorizedAccess<'a, 'b> { match self .broker .subscriptions - .read() + .write() .await .notify(Some(&changed), &db) .await @@ -1454,7 +1478,7 @@ impl<'a, 'b> AuthorizedAccess<'a, 'b> { let (sender, receiver) = mpsc::channel(10); let subscription = ChangeSubscription { entries: valid_entries, - sender, + sender: Some(sender), permissions: self.permissions.clone(), };