From 8c0f8df46d2a07cf17b914a312c63ba84b2d5f5d Mon Sep 17 00:00:00 2001 From: Yuwei Ba Date: Fri, 2 Aug 2024 16:26:05 +1000 Subject: [PATCH] feat(api): implement some control APIs (#524) * mem usage * DNS * restart * clippy * clippy * clippy * clippy --- Cargo.lock | 11 ++ clash/tests/data/config/rules.yaml | 2 +- clash_lib/Cargo.toml | 2 + clash_lib/src/app/api/handlers/dns.rs | 109 +++++++++++++++++- clash_lib/src/app/api/handlers/memory.rs | 77 +++++++++++++ clash_lib/src/app/api/handlers/mod.rs | 5 +- clash_lib/src/app/api/handlers/restart.rs | 53 +++++++++ clash_lib/src/app/api/mod.rs | 8 +- .../src/app/dispatcher/statistics_manager.rs | 7 ++ 9 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 clash_lib/src/app/api/handlers/memory.rs create mode 100644 clash_lib/src/app/api/handlers/restart.rs diff --git a/Cargo.lock b/Cargo.lock index 6cc6575c..34b17bcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1005,6 +1005,7 @@ dependencies = [ "lru_time_cache", "maxminddb", "md-5", + "memory-stats", "mockall", "murmur3", "netstack-lwip", @@ -3101,6 +3102,16 @@ dependencies = [ "libc", ] +[[package]] +name = "memory-stats" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c73f5c649995a115e1a0220b35e4df0a1294500477f97a91d0660fb5abeb574a" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "merlin" version = "3.0.0" diff --git a/clash/tests/data/config/rules.yaml b/clash/tests/data/config/rules.yaml index b7b85d87..c8932bc8 100644 --- a/clash/tests/data/config/rules.yaml +++ b/clash/tests/data/config/rules.yaml @@ -4,7 +4,7 @@ socks-port: 8889 mixed-port: 8899 tun: - enable: true + enable: false device-id: "dev://utun1989" ipv6: true diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 1d3ea3a8..9787e89f 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -123,6 +123,8 @@ console-subscriber = { version = "0.4.0" } tracing-timing = { version = "0.6.0" } criterion = { version = "0.5", features = ["html_reports", "async_tokio"], optional = true } +memory-stats = "1.0.0" + [dev-dependencies] tempfile = "3.10" mockall = "0.13.0" diff --git a/clash_lib/src/app/api/handlers/dns.rs b/clash_lib/src/app/api/handlers/dns.rs index 3d8ff7e2..08e2d3c8 100644 --- a/clash_lib/src/app/api/handlers/dns.rs +++ b/clash_lib/src/app/api/handlers/dns.rs @@ -1,7 +1,15 @@ use std::sync::Arc; -use axum::{response::IntoResponse, routing::get, Router}; +use axum::{ + extract::{Query, State}, + response::IntoResponse, + routing::get, + Json, Router, +}; +use hickory_proto::{op::Message, rr::RecordType}; use http::StatusCode; +use serde::Deserialize; +use serde_json::{Map, Value}; use crate::app::{api::AppState, dns::ThreadSafeDNSResolver}; @@ -14,10 +22,103 @@ struct DNSState { pub fn routes(resolver: ThreadSafeDNSResolver) -> Router> { let state = DNSState { resolver }; Router::new() - .route("/dns", get(query_dns)) + .route("/query", get(query_dns)) .with_state(state) } -async fn query_dns() -> impl IntoResponse { - StatusCode::NOT_IMPLEMENTED +#[derive(Deserialize)] +struct DnsQUery { + name: String, + #[serde(rename = "type")] + typ: String, +} + +async fn query_dns( + State(state): State, + q: Query, +) -> impl IntoResponse { + if let crate::app::dns::ResolverKind::System = state.resolver.kind() { + return (StatusCode::BAD_REQUEST, "Clash resolver is not enabled.") + .into_response(); + } + let typ: RecordType = q.typ.parse().unwrap_or(RecordType::A); + let mut m = Message::new(); + + let name = hickory_proto::rr::Name::from_str_relaxed(q.name.as_str()); + + if name.is_err() { + return (StatusCode::BAD_REQUEST, "Invalid name").into_response(); + } + + m.add_query(hickory_proto::op::Query::query(name.unwrap(), typ)); + + match state.resolver.exchange(m).await { + Ok(response) => { + let mut resp = Map::new(); + resp.insert("Status".to_owned(), response.response_code().low().into()); + resp.insert( + "Question".to_owned(), + response + .queries() + .iter() + .map(|x| { + let mut data = Map::new(); + data.insert("name".to_owned(), x.name().to_string().into()); + data.insert( + "qtype".to_owned(), + u16::from(x.query_type()).into(), + ); + data.insert( + "qclass".to_owned(), + u16::from(x.query_class()).into(), + ); + data.into() + }) + .collect::>() + .into(), + ); + + resp.insert("TC".to_owned(), response.truncated().into()); + resp.insert("RD".to_owned(), response.recursion_desired().into()); + resp.insert("RA".to_owned(), response.recursion_available().into()); + resp.insert("AD".to_owned(), response.authentic_data().into()); + resp.insert("CD".to_owned(), response.checking_disabled().into()); + + let rr2json = |rr: &hickory_proto::rr::Record| -> Value { + let mut data = Map::new(); + data.insert("name".to_owned(), rr.name().to_string().into()); + data.insert("type".to_owned(), u16::from(rr.record_type()).into()); + data.insert("ttl".to_owned(), rr.ttl().into()); + data.insert( + "data".to_owned(), + rr.data().map(|x| x.to_string()).unwrap_or_default().into(), + ); + data.into() + }; + + if response.answer_count() > 0 { + resp.insert( + "Answer".to_owned(), + response.answers().iter().map(rr2json).collect(), + ); + } + + if response.name_server_count() > 0 { + resp.insert( + "Authority".to_owned(), + response.name_servers().iter().map(rr2json).collect(), + ); + } + + if response.additional_count() > 0 { + resp.insert( + "Additional".to_owned(), + response.additionals().iter().map(rr2json).collect(), + ); + } + + Json(resp).into_response() + } + Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), + } } diff --git a/clash_lib/src/app/api/handlers/memory.rs b/clash_lib/src/app/api/handlers/memory.rs new file mode 100644 index 00000000..7451075c --- /dev/null +++ b/clash_lib/src/app/api/handlers/memory.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use axum::{ + body::Body, + extract::{ws::Message, FromRequest, Query, Request, State, WebSocketUpgrade}, + response::IntoResponse, + Json, +}; +use http::HeaderMap; +use serde::{Deserialize, Serialize}; +use tracing::{debug, warn}; + +use crate::app::api::AppState; + +use super::utils::is_request_websocket; + +#[derive(Deserialize)] +pub struct GetMemoryQuery { + interval: Option, +} + +#[derive(Serialize)] +struct GetMemoryResponse { + inuse: usize, + oslimit: usize, +} +pub async fn handle( + headers: HeaderMap, + State(state): State>, + q: Query, + req: Request, +) -> impl IntoResponse { + if !is_request_websocket(headers) { + let mgr = state.statistics_manager.clone(); + let snapshot = GetMemoryResponse { + inuse: mgr.memory_usage(), + oslimit: 0, + }; + return Json(snapshot).into_response(); + } + + let ws = match WebSocketUpgrade::from_request(req, &state).await { + Ok(ws) => ws, + Err(e) => { + warn!("ws upgrade error: {}", e); + return e.into_response(); + } + }; + + ws.on_failed_upgrade(|e| { + warn!("ws upgrade error: {}", e); + }) + .on_upgrade(move |mut socket| async move { + let interval = q.interval; + + let mgr = state.statistics_manager.clone(); + + loop { + let snapshot = GetMemoryResponse { + inuse: mgr.memory_usage(), + oslimit: 0, + }; + let j = serde_json::to_vec(&snapshot).unwrap(); + let body = String::from_utf8(j).unwrap(); + + if let Err(e) = socket.send(Message::Text(body)).await { + debug!("send memory snapshot failed: {}", e); + break; + } + + tokio::time::sleep(tokio::time::Duration::from_secs( + interval.unwrap_or(1), + )) + .await; + } + }) +} diff --git a/clash_lib/src/app/api/handlers/mod.rs b/clash_lib/src/app/api/handlers/mod.rs index 5c60eb73..0ea4ccc9 100644 --- a/clash_lib/src/app/api/handlers/mod.rs +++ b/clash_lib/src/app/api/handlers/mod.rs @@ -3,9 +3,12 @@ pub mod connection; pub mod dns; pub mod hello; pub mod log; +pub mod memory; pub mod provider; pub mod proxy; +pub mod restart; pub mod rule; pub mod traffic; -mod utils; pub mod version; + +mod utils; diff --git a/clash_lib/src/app/api/handlers/restart.rs b/clash_lib/src/app/api/handlers/restart.rs new file mode 100644 index 00000000..213d7ab8 --- /dev/null +++ b/clash_lib/src/app/api/handlers/restart.rs @@ -0,0 +1,53 @@ +#[cfg(unix)] +use std::os::unix::process::CommandExt; + +use axum::{response::IntoResponse, Json}; +use serde_json::Map; + +pub async fn handle() -> impl IntoResponse { + match std::env::current_exe() { + Ok(exec) => { + let mut map = Map::new(); + map.insert("status".to_owned(), "ok".into()); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + #[cfg(unix)] + { + use tracing::info; + + let err = std::process::Command::new(exec) + .args(std::env::args().skip(1)) + .envs(std::env::vars()) + .exec(); + info!("process restarted: {}", err); + } + #[cfg(windows)] + { + use tracing::error; + + match std::process::Command::new(exec) + .args(std::env::args().skip(1)) + .envs(std::env::vars()) + .stdin(std::process::Stdio::inherit()) + .stdout(std::process::Stdio::inherit()) + .stderr(std::process::Stdio::inherit()) + .spawn() + { + Ok(_) => { + // exit the current process + std::process::exit(0); + } + Err(e) => { + error!("Failed to restart: {}", e); + } + } + } + }); + Json(map).into_response() + } + Err(e) => { + (http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + } +} diff --git a/clash_lib/src/app/api/mod.rs b/clash_lib/src/app/api/mod.rs index 8aa1db37..a2a0e0c2 100644 --- a/clash_lib/src/app/api/mod.rs +++ b/clash_lib/src/app/api/mod.rs @@ -1,6 +1,10 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; -use axum::{response::Redirect, routing::get, Router}; +use axum::{ + response::Redirect, + routing::{get, post}, + Router, +}; use http::{header, Method}; use tokio::sync::{broadcast::Sender, Mutex}; @@ -68,6 +72,8 @@ pub fn get_api_runner( .route("/logs", get(handlers::log::handle)) .route("/traffic", get(handlers::traffic::handle)) .route("/version", get(handlers::version::handle)) + .route("/memory", get(handlers::memory::handle)) + .route("/restart", post(handlers::restart::handle)) .nest( "/configs", handlers::config::routes( diff --git a/clash_lib/src/app/dispatcher/statistics_manager.rs b/clash_lib/src/app/dispatcher/statistics_manager.rs index fe027b3f..d98c8eef 100644 --- a/clash_lib/src/app/dispatcher/statistics_manager.rs +++ b/clash_lib/src/app/dispatcher/statistics_manager.rs @@ -7,6 +7,7 @@ use std::{ }; use chrono::Utc; +use memory_stats::memory_stats; use serde::Serialize; use tokio::sync::{oneshot::Sender, Mutex, RwLock}; @@ -55,6 +56,7 @@ pub struct Snapshot { download_total: i64, upload_total: i64, connections: Vec, + memory: usize, } type ConnectionMap = HashMap)>; @@ -176,6 +178,7 @@ impl Manager { .upload_total .load(std::sync::atomic::Ordering::Relaxed), connections, + memory: self.memory_usage(), } } @@ -189,6 +192,10 @@ impl Manager { self.download_total.store(0, Ordering::Relaxed); } + pub fn memory_usage(&self) -> usize { + memory_stats().map(|x| x.physical_mem).unwrap_or(0) + } + async fn kick_off(&self) { let mut ticker = tokio::time::interval(std::time::Duration::from_secs(1)); loop {