Skip to content

Commit

Permalink
refactor: remove global variable TX
Browse files Browse the repository at this point in the history
  • Loading branch information
lomirus committed Nov 4, 2024
1 parent 0ae324a commit 27dfced
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
16 changes: 6 additions & 10 deletions src/file_layer/watcher.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{path::PathBuf, time::Duration};
use std::{path::PathBuf, sync::Arc, time::Duration};

use notify::{Error, RecommendedWatcher, RecursiveMode, Watcher};
use notify_debouncer_full::{
Expand All @@ -7,16 +7,9 @@ use notify_debouncer_full::{
use tokio::{
fs,
runtime::Handle,
sync::mpsc::{channel, Receiver},
sync::{broadcast, mpsc::{channel, Receiver}},
};

use crate::TX;

async fn broadcast() {
let tx = TX.get().unwrap();
let _ = tx.send(());
}

pub(crate) async fn create_watcher(
root: PathBuf,
) -> Result<
Expand Down Expand Up @@ -67,6 +60,7 @@ pub async fn watch(
root_path: PathBuf,
mut debouncer: Debouncer<RecommendedWatcher, FileIdMap>,
mut rx: Receiver<Result<Vec<DebouncedEvent>, Vec<Error>>>,
tx: Arc<broadcast::Sender<()>>
) {
debouncer
.watcher()
Expand Down Expand Up @@ -127,7 +121,9 @@ pub async fn watch(
}
}
if files_changed {
broadcast().await;
if let Err(err) = tx.send(()) {
log::error!("{:?}", err);
}
}
}
}
Expand Down
24 changes: 16 additions & 8 deletions src/http_layer/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use axum::{
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{fs, io::ErrorKind, sync::Arc};
use tokio::net::TcpListener;
use tokio::{net::TcpListener, sync::broadcast};

use crate::{ADDR, ROOT, TX};
use crate::{ADDR, ROOT};

/// JS script containing a function that takes in the address and connects to the websocket.
const WEBSOCKET_FUNCTION: &str = include_str!("../templates/websocket.js");
Expand All @@ -32,6 +32,14 @@ pub struct Options {
pub index_listing: bool,
}

pub(crate) struct AppState {
/// Always hard reload the page instead of hot-reload
pub(crate) hard_reload: bool,
/// Show page list of the current URL if `index.html` does not exist
pub(crate) index_listing: bool,
pub(crate) tx: Arc<broadcast::Sender<()>>,
}

impl Default for Options {
fn default() -> Self {
Self {
Expand All @@ -41,7 +49,8 @@ impl Default for Options {
}
}

pub(crate) fn create_server(options: Options) -> Router {
pub(crate) fn create_server(state: AppState) -> Router {
let tx = state.tx.clone();
Router::new()
.route("/", get(static_assets))
.route("/*path", get(static_assets))
Expand All @@ -51,15 +60,14 @@ pub(crate) fn create_server(options: Options) -> Router {
ws.on_failed_upgrade(|error| {
log::error!("Failed to upgrade websocket: {}", error);
})
.on_upgrade(on_websocket_upgrade)
.on_upgrade(|socket: WebSocket| on_websocket_upgrade(socket, tx))
}),
)
.with_state(Arc::new(options))
.with_state(Arc::new(state))
}

async fn on_websocket_upgrade(socket: WebSocket) {
async fn on_websocket_upgrade(socket: WebSocket, tx: Arc<broadcast::Sender<()>>) {
let (mut sender, mut receiver) = socket.split();
let tx = TX.get().unwrap();
let mut rx = tx.subscribe();
let mut send_task = tokio::spawn(async move {
while rx.recv().await.is_ok() {
Expand Down Expand Up @@ -103,7 +111,7 @@ fn get_index_listing(uri_path: &str) -> String {
}

async fn static_assets(
state: State<Arc<Options>>,
state: State<Arc<AppState>>,
req: Request<Body>,
) -> (StatusCode, HeaderMap, Body) {
let addr = ADDR.get().unwrap();
Expand Down
17 changes: 11 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,19 @@ pub use http_layer::server::Options;
use file_layer::watcher::{create_watcher, watch};
use http_layer::{
listener::create_listener,
server::{create_server, serve},
server::{create_server, serve, AppState},
};
use local_ip_address::local_ip;
use notify::RecommendedWatcher;
use notify_debouncer_full::{DebouncedEvent, Debouncer, FileIdMap};
use std::{error::Error, net::IpAddr, path::PathBuf};
use std::{error::Error, net::IpAddr, path::PathBuf, sync::Arc};
use tokio::{
net::TcpListener,
sync::{broadcast, mpsc::Receiver, OnceCell},
};

static ADDR: OnceCell<String> = OnceCell::const_new();
static ROOT: OnceCell<PathBuf> = OnceCell::const_new();
static TX: OnceCell<broadcast::Sender<()>> = OnceCell::const_new();

pub struct Listener {
tcp_listener: TcpListener,
Expand All @@ -57,10 +56,16 @@ impl Listener {
pub async fn start(self, options: Options) -> Result<(), Box<dyn Error>> {
ROOT.set(self.root_path.clone())?;
let (tx, _) = broadcast::channel(16);
TX.set(tx)?;

let watcher_future = tokio::spawn(watch(self.root_path, self.debouncer, self.rx));
let server_future = tokio::spawn(serve(self.tcp_listener, create_server(options)));
let arc_tx = Arc::new(tx);
let app_state = AppState {
hard_reload: options.hard_reload,
index_listing: options.index_listing,
tx: arc_tx.clone(),
};

let watcher_future = tokio::spawn(watch(self.root_path, self.debouncer, self.rx, arc_tx));
let server_future = tokio::spawn(serve(self.tcp_listener, create_server(app_state)));

tokio::try_join!(watcher_future, server_future)?;

Expand Down

0 comments on commit 27dfced

Please sign in to comment.