Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit

Permalink
Implement "unsubscribe"
Browse files Browse the repository at this point in the history
  • Loading branch information
argerus committed Aug 29, 2023
1 parent ded3360 commit 84437c6
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 74 deletions.
45 changes: 27 additions & 18 deletions kuksa_databroker/databroker/src/viss/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use tracing::{debug, error, info};
use futures::{channel::mpsc, Sink};
use futures::{stream::StreamExt, Stream};

use crate::broker;
use crate::viss::v2;
use crate::{broker, viss::v2::VissService};

pub async fn serve(
addr: impl Into<std::net::SocketAddr>,
Expand All @@ -35,7 +35,7 @@ pub async fn serve(
// signal: F
) -> Result<(), Box<dyn std::error::Error>> {
let app = Router::new()
.route("/", get(handle_upgrade::<broker::DataBroker>))
.route("/", get(handle_upgrade))
.with_state(broker);

let addr = addr.into();
Expand All @@ -52,20 +52,17 @@ pub async fn serve(
}

// Handle upgrade request
async fn handle_upgrade<T>(
async fn handle_upgrade(
ws: WebSocketUpgrade,
axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<SocketAddr>,
axum::extract::State(state): axum::extract::State<T>,
) -> impl IntoResponse
where
T: v2::VissService,
{
axum::extract::State(state): axum::extract::State<broker::DataBroker>,
) -> impl IntoResponse {
debug!("Received websocket upgrade request");
ws.on_upgrade(move |socket| handle_websocket(socket, addr, state))
}

// Handle websocket (one per connection)
async fn handle_websocket(socket: WebSocket, addr: SocketAddr, state: impl v2::VissService) {
async fn handle_websocket(socket: WebSocket, addr: SocketAddr, broker: broker::DataBroker) {
let valid_subprotocol = match socket.protocol() {
Some(subprotocol) => match subprotocol.to_str() {
Ok("VISSv2") => true,
Expand All @@ -92,14 +89,14 @@ async fn handle_websocket(socket: WebSocket, addr: SocketAddr, state: impl v2::V

let (write, read) = socket.split();

handle_viss_v2(write, read, addr, state).await;
handle_viss_v2(write, read, addr, broker).await;
}

async fn handle_viss_v2<W, R>(
write: W,
mut read: R,
client_addr: SocketAddr,
viss: impl v2::VissService,
broker: broker::DataBroker,
) where
W: Sink<Message> + Unpin + Send + 'static,
<W as Sink<Message>>::Error: Send,
Expand All @@ -109,6 +106,8 @@ async fn handle_viss_v2<W, R>(
// single consumer will write to the socket.
let (sender, receiver) = mpsc::channel::<Message>(10);

let server = v2::Server::new(broker);

let mut write_task = tokio::spawn(async move {
let _ = receiver.map(Ok).forward(write).await;
});
Expand All @@ -125,24 +124,25 @@ async fn handle_viss_v2<W, R>(
Ok(request) => match request {
v2::Request::Get(request) => {
debug!("Get request parsed successfully");
match viss.get(request).await {
match server.get(request).await {
Ok(response) => serde_json::to_string(&response),
Err(error_response) => serde_json::to_string(&error_response),
}
}
v2::Request::Set(request) => {
debug!("Set request parsed successfully");
match viss.set(request).await {
debug!("Set request successfully parsed");
match server.set(request).await {
Ok(response) => serde_json::to_string(&response),
Err(error_response) => serde_json::to_string(&error_response),
}
}
v2::Request::Subscribe(request) => {
debug!("Subscribe request parsed successfully");
match viss.subscribe(request).await {
debug!("Subscribe request successfully parsed");
match server.subscribe(request).await {
Ok((response, stream)) => {
// Setup background stream
let mut subscription_sender = sender.clone();

tokio::spawn(async move {
let mut stream = Box::pin(stream);
while let Some(item) = stream.next().await {
Expand All @@ -154,14 +154,15 @@ async fn handle_viss_v2<W, R>(
};

if let Ok(serialized) = serialized {
debug!("Sending notification: {}", serialized);
match subscription_sender
.try_send(Message::Text(serialized))
{
Ok(_) => {
debug!("Successfully sent response")
debug!("Successfully sent notification")
}
Err(err) => {
debug!("Failed to send response: {err}")
debug!("Failed to send notification: {err}")
}
};
}
Expand All @@ -174,12 +175,20 @@ async fn handle_viss_v2<W, R>(
Err(error_response) => serde_json::to_string(&error_response),
}
}
v2::Request::Unsubscribe(request) => {
debug!("Unsubscribe request successfully parsed");
match server.unsubscribe(request).await {
Ok(response) => serde_json::to_string(&response),
Err(error_response) => serde_json::to_string(&error_response),
}
}
},
Err(err) => serde_json::to_string(&err),
};

// Send it
if let Ok(serialized) = serialized {
debug!("Sending response: {}", serialized);
let mut sender = sender;
match sender.try_send(Message::Text(serialized)) {
Ok(_) => debug!("Successfully sent response"),
Expand Down
159 changes: 110 additions & 49 deletions kuksa_databroker/databroker/src/viss/v2/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ use std::{
collections::{HashMap, HashSet},
convert::{TryFrom, TryInto},
pin::Pin,
sync::Arc,
time::SystemTime,
};

use futures::{Stream, StreamExt};
use futures::{
stream::{AbortHandle, Abortable},
Stream, StreamExt,
};
use tokio::sync::RwLock;

use crate::{broker, permissions};

use super::types::*;

pub fn parse_request(msg: &str) -> Result<Request, Error> {
let request: Request = serde_json::from_str(msg).map_err(|_err| Error::BadRequest)?;
Ok(request)
}

#[tonic::async_trait]
pub(crate) trait VissService: Send + Sync + 'static {
async fn get(&self, request: GetRequest) -> Result<GetSuccessResponse, GetErrorResponse>;
Expand All @@ -42,52 +42,37 @@ pub(crate) trait VissService: Send + Sync + 'static {
&self,
request: SubscribeRequest,
) -> Result<(SubscribeSuccessResponse, Self::SubscribeStream), SubscribeErrorResponse>;

async fn unsubscribe(
&self,
request: UnsubscribeRequest,
) -> Result<UnsubscribeSuccessResponse, UnsubscribeErrorResponse>;
}

fn convert_to_viss_stream(
subscription_id: SubscriptionId,
input: impl Stream<Item = broker::EntryUpdates>,
) -> impl Stream<Item = Result<SubscriptionEvent, SubscriptionErrorEvent>> {
input.map(move |item| {
let ts = SystemTime::now().into();
let subscription_id = subscription_id.clone();
match item.updates.get(0) {
Some(item) => match (&item.update.path, &item.update.datapoint) {
(Some(path), Some(datapoint)) => match datapoint.clone().try_into() {
Ok(dp) => Ok(SubscriptionEvent {
subscription_id,
data: Data::Object(DataObject {
path: path.clone().into(),
dp,
}),
ts,
}),
Err(error) => Err(SubscriptionErrorEvent {
subscription_id,
error,
ts,
}),
},
(_, _) => Err(SubscriptionErrorEvent {
subscription_id,
error: Error::InternalServerError,
ts,
}),
},
None => Err(SubscriptionErrorEvent {
subscription_id,
error: Error::InternalServerError,
ts,
}),
pub struct Server {
broker: broker::DataBroker,
subscriptions: Arc<RwLock<HashMap<SubscriptionId, AbortHandle>>>,
}

impl Server {
pub fn new(broker: broker::DataBroker) -> Self {
Self {
broker,
subscriptions: Arc::new(RwLock::new(HashMap::new())),
}
})
}
}

pub fn parse_request(msg: &str) -> Result<Request, Error> {
let request: Request = serde_json::from_str(msg).map_err(|_err| Error::BadRequest)?;
Ok(request)
}

#[tonic::async_trait]
impl VissService for broker::DataBroker {
impl VissService for Server {
async fn get(&self, request: GetRequest) -> Result<GetSuccessResponse, GetErrorResponse> {
let permissions = &permissions::ALLOW_ALL;
let broker = self.authorized_access(permissions);
let broker = self.broker.authorized_access(permissions);

match broker.get_datapoint_by_path(request.path.as_ref()).await {
Ok(datapoint) => match datapoint.try_into() {
Expand Down Expand Up @@ -118,7 +103,7 @@ impl VissService for broker::DataBroker {

async fn set(&self, request: SetRequest) -> Result<SetSuccessResponse, SetErrorResponse> {
let permissions = &permissions::ALLOW_ALL;
let broker = self.authorized_access(permissions);
let broker = self.broker.authorized_access(permissions);

match broker.get_metadata_by_path(request.path.as_ref()).await {
Some(metadata) => {
Expand Down Expand Up @@ -235,28 +220,65 @@ impl VissService for broker::DataBroker {
>,
>;

async fn unsubscribe(
&self,
request: UnsubscribeRequest,
) -> Result<UnsubscribeSuccessResponse, UnsubscribeErrorResponse> {
let subscription_id = request.subscription_id;
let request_id = request.request_id;
match self.subscriptions.read().await.get(&subscription_id) {
Some(abort_handle) => {
abort_handle.abort();
Ok(UnsubscribeSuccessResponse {
request_id,
subscription_id,
ts: SystemTime::now().into(),
})
}
None => Err(UnsubscribeErrorResponse {
request_id,
subscription_id,
error: Error::BadRequest,
ts: SystemTime::now().into(),
}),
}
}

async fn subscribe(
&self,
request: SubscribeRequest,
) -> Result<(SubscribeSuccessResponse, Self::SubscribeStream), SubscribeErrorResponse> {
let permissions = &permissions::ALLOW_ALL;
let broker = self.authorized_access(permissions);
let broker = self.broker.authorized_access(permissions);

let entries = HashMap::from([(
Into::<String>::into(request.path),
HashSet::from([broker::Field::Datapoint]),
)]);
match broker.subscribe(entries).await {
Ok(stream) => {
let subscription_id = SubscriptionId::from(request.request_id.as_ref());
let subscription_id = SubscriptionId::from(request.request_id.as_ref().to_owned());

let (abort_handle, abort_registration) = AbortHandle::new_pair();

// Make the stream abortable
let stream = Abortable::new(stream, abort_registration);

// Register abort handle
self.subscriptions
.write()
.await
.insert(subscription_id.clone(), abort_handle);

let stream = convert_to_viss_stream(subscription_id.clone(), stream);

Ok((
SubscribeSuccessResponse {
request_id: request.request_id,
subscription_id: subscription_id.clone(),
subscription_id,
ts: SystemTime::now().into(),
},
Box::pin(convert_to_viss_stream(subscription_id, stream)),
Box::pin(stream),
))
}
Err(err) => Err(SubscribeErrorResponse {
Expand All @@ -271,3 +293,42 @@ impl VissService for broker::DataBroker {
}
}
}

fn convert_to_viss_stream(
subscription_id: SubscriptionId,
stream: impl Stream<Item = broker::EntryUpdates>,
) -> impl Stream<Item = Result<SubscriptionEvent, SubscriptionErrorEvent>> {
stream.map(move |item| {
let ts = SystemTime::now().into();
let subscription_id = subscription_id.clone();
match item.updates.get(0) {
Some(item) => match (&item.update.path, &item.update.datapoint) {
(Some(path), Some(datapoint)) => match datapoint.clone().try_into() {
Ok(dp) => Ok(SubscriptionEvent {
subscription_id,
data: Data::Object(DataObject {
path: path.clone().into(),
dp,
}),
ts,
}),
Err(error) => Err(SubscriptionErrorEvent {
subscription_id,
error,
ts,
}),
},
(_, _) => Err(SubscriptionErrorEvent {
subscription_id,
error: Error::InternalServerError,
ts,
}),
},
None => Err(SubscriptionErrorEvent {
subscription_id,
error: Error::InternalServerError,
ts,
}),
}
})
}
Loading

0 comments on commit 84437c6

Please sign in to comment.