Skip to content

Commit

Permalink
feat(api): implement some control APIs (#524)
Browse files Browse the repository at this point in the history
* mem usage

* DNS

* restart

* clippy

* clippy

* clippy

* clippy
  • Loading branch information
ibigbug authored Aug 2, 2024
1 parent 31c5c84 commit 8c0f8df
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 7 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion clash/tests/data/config/rules.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ socks-port: 8889
mixed-port: 8899

tun:
enable: true
enable: false
device-id: "dev://utun1989"

ipv6: true
Expand Down
2 changes: 2 additions & 0 deletions clash_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ console-subscriber = { version = "0.4.0" }
tracing-timing = { version = "0.6.0" }
criterion = { version = "0.5", features = ["html_reports", "async_tokio"], optional = true }

memory-stats = "1.0.0"

[dev-dependencies]
tempfile = "3.10"
mockall = "0.13.0"
Expand Down
109 changes: 105 additions & 4 deletions clash_lib/src/app/api/handlers/dns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::sync::Arc;

use axum::{response::IntoResponse, routing::get, Router};
use axum::{
extract::{Query, State},
response::IntoResponse,
routing::get,
Json, Router,
};
use hickory_proto::{op::Message, rr::RecordType};
use http::StatusCode;
use serde::Deserialize;
use serde_json::{Map, Value};

use crate::app::{api::AppState, dns::ThreadSafeDNSResolver};

Expand All @@ -14,10 +22,103 @@ struct DNSState {
pub fn routes(resolver: ThreadSafeDNSResolver) -> Router<Arc<AppState>> {
let state = DNSState { resolver };
Router::new()
.route("/dns", get(query_dns))
.route("/query", get(query_dns))
.with_state(state)
}

async fn query_dns() -> impl IntoResponse {
StatusCode::NOT_IMPLEMENTED
#[derive(Deserialize)]
struct DnsQUery {
name: String,
#[serde(rename = "type")]
typ: String,
}

async fn query_dns(
State(state): State<DNSState>,
q: Query<DnsQUery>,
) -> impl IntoResponse {
if let crate::app::dns::ResolverKind::System = state.resolver.kind() {
return (StatusCode::BAD_REQUEST, "Clash resolver is not enabled.")
.into_response();
}
let typ: RecordType = q.typ.parse().unwrap_or(RecordType::A);
let mut m = Message::new();

let name = hickory_proto::rr::Name::from_str_relaxed(q.name.as_str());

if name.is_err() {
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}

m.add_query(hickory_proto::op::Query::query(name.unwrap(), typ));

match state.resolver.exchange(m).await {
Ok(response) => {
let mut resp = Map::new();
resp.insert("Status".to_owned(), response.response_code().low().into());
resp.insert(
"Question".to_owned(),
response
.queries()
.iter()
.map(|x| {
let mut data = Map::new();
data.insert("name".to_owned(), x.name().to_string().into());
data.insert(
"qtype".to_owned(),
u16::from(x.query_type()).into(),
);
data.insert(
"qclass".to_owned(),
u16::from(x.query_class()).into(),
);
data.into()
})
.collect::<Vec<Value>>()
.into(),
);

resp.insert("TC".to_owned(), response.truncated().into());
resp.insert("RD".to_owned(), response.recursion_desired().into());
resp.insert("RA".to_owned(), response.recursion_available().into());
resp.insert("AD".to_owned(), response.authentic_data().into());
resp.insert("CD".to_owned(), response.checking_disabled().into());

let rr2json = |rr: &hickory_proto::rr::Record| -> Value {
let mut data = Map::new();
data.insert("name".to_owned(), rr.name().to_string().into());
data.insert("type".to_owned(), u16::from(rr.record_type()).into());
data.insert("ttl".to_owned(), rr.ttl().into());
data.insert(
"data".to_owned(),
rr.data().map(|x| x.to_string()).unwrap_or_default().into(),
);
data.into()
};

if response.answer_count() > 0 {
resp.insert(
"Answer".to_owned(),
response.answers().iter().map(rr2json).collect(),
);
}

if response.name_server_count() > 0 {
resp.insert(
"Authority".to_owned(),
response.name_servers().iter().map(rr2json).collect(),
);
}

if response.additional_count() > 0 {
resp.insert(
"Additional".to_owned(),
response.additionals().iter().map(rr2json).collect(),
);
}

Json(resp).into_response()
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
}
77 changes: 77 additions & 0 deletions clash_lib/src/app/api/handlers/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use std::sync::Arc;

use axum::{
body::Body,
extract::{ws::Message, FromRequest, Query, Request, State, WebSocketUpgrade},
response::IntoResponse,
Json,
};
use http::HeaderMap;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};

use crate::app::api::AppState;

use super::utils::is_request_websocket;

#[derive(Deserialize)]
pub struct GetMemoryQuery {
interval: Option<u64>,
}

#[derive(Serialize)]
struct GetMemoryResponse {
inuse: usize,
oslimit: usize,
}
pub async fn handle(
headers: HeaderMap,
State(state): State<Arc<AppState>>,
q: Query<GetMemoryQuery>,
req: Request<Body>,
) -> impl IntoResponse {
if !is_request_websocket(headers) {
let mgr = state.statistics_manager.clone();
let snapshot = GetMemoryResponse {
inuse: mgr.memory_usage(),
oslimit: 0,
};
return Json(snapshot).into_response();
}

let ws = match WebSocketUpgrade::from_request(req, &state).await {
Ok(ws) => ws,
Err(e) => {
warn!("ws upgrade error: {}", e);
return e.into_response();
}
};

ws.on_failed_upgrade(|e| {
warn!("ws upgrade error: {}", e);
})
.on_upgrade(move |mut socket| async move {
let interval = q.interval;

let mgr = state.statistics_manager.clone();

loop {
let snapshot = GetMemoryResponse {
inuse: mgr.memory_usage(),
oslimit: 0,
};
let j = serde_json::to_vec(&snapshot).unwrap();
let body = String::from_utf8(j).unwrap();

if let Err(e) = socket.send(Message::Text(body)).await {
debug!("send memory snapshot failed: {}", e);
break;
}

tokio::time::sleep(tokio::time::Duration::from_secs(
interval.unwrap_or(1),
))
.await;
}
})
}
5 changes: 4 additions & 1 deletion clash_lib/src/app/api/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ pub mod connection;
pub mod dns;
pub mod hello;
pub mod log;
pub mod memory;
pub mod provider;
pub mod proxy;
pub mod restart;
pub mod rule;
pub mod traffic;
mod utils;
pub mod version;

mod utils;
53 changes: 53 additions & 0 deletions clash_lib/src/app/api/handlers/restart.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#[cfg(unix)]
use std::os::unix::process::CommandExt;

use axum::{response::IntoResponse, Json};
use serde_json::Map;

pub async fn handle() -> impl IntoResponse {
match std::env::current_exe() {
Ok(exec) => {
let mut map = Map::new();
map.insert("status".to_owned(), "ok".into());
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

#[cfg(unix)]
{
use tracing::info;

let err = std::process::Command::new(exec)
.args(std::env::args().skip(1))
.envs(std::env::vars())
.exec();
info!("process restarted: {}", err);
}
#[cfg(windows)]
{
use tracing::error;

match std::process::Command::new(exec)
.args(std::env::args().skip(1))
.envs(std::env::vars())
.stdin(std::process::Stdio::inherit())
.stdout(std::process::Stdio::inherit())
.stderr(std::process::Stdio::inherit())
.spawn()
{
Ok(_) => {
// exit the current process
std::process::exit(0);
}
Err(e) => {
error!("Failed to restart: {}", e);
}
}
}
});
Json(map).into_response()
}
Err(e) => {
(http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
}
}
8 changes: 7 additions & 1 deletion clash_lib/src/app/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::{net::SocketAddr, path::PathBuf, sync::Arc};

use axum::{response::Redirect, routing::get, Router};
use axum::{
response::Redirect,
routing::{get, post},
Router,
};

use http::{header, Method};
use tokio::sync::{broadcast::Sender, Mutex};
Expand Down Expand Up @@ -68,6 +72,8 @@ pub fn get_api_runner(
.route("/logs", get(handlers::log::handle))
.route("/traffic", get(handlers::traffic::handle))
.route("/version", get(handlers::version::handle))
.route("/memory", get(handlers::memory::handle))
.route("/restart", post(handlers::restart::handle))
.nest(
"/configs",
handlers::config::routes(
Expand Down
7 changes: 7 additions & 0 deletions clash_lib/src/app/dispatcher/statistics_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
};

use chrono::Utc;
use memory_stats::memory_stats;
use serde::Serialize;
use tokio::sync::{oneshot::Sender, Mutex, RwLock};

Expand Down Expand Up @@ -55,6 +56,7 @@ pub struct Snapshot {
download_total: i64,
upload_total: i64,
connections: Vec<TrackerInfo>,
memory: usize,
}

type ConnectionMap = HashMap<uuid::Uuid, (Tracked, Sender<()>)>;
Expand Down Expand Up @@ -176,6 +178,7 @@ impl Manager {
.upload_total
.load(std::sync::atomic::Ordering::Relaxed),
connections,
memory: self.memory_usage(),
}
}

Expand All @@ -189,6 +192,10 @@ impl Manager {
self.download_total.store(0, Ordering::Relaxed);
}

pub fn memory_usage(&self) -> usize {
memory_stats().map(|x| x.physical_mem).unwrap_or(0)
}

async fn kick_off(&self) {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
Expand Down

0 comments on commit 8c0f8df

Please sign in to comment.