Skip to content

Commit

Permalink
Merge pull request #20 from kolapapa/0.6
Browse files Browse the repository at this point in the history
Remove uuid && optimization
  • Loading branch information
kolapapa authored May 18, 2022
2 parents 45dddc6 + 1914752 commit d4eb57b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 66 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ rand = "0.8.5"
socket2 = { version = "0.4.4", features = ["all"] }
thiserror = "1.0.30"
tokio = { version = "1.17.0", features = ["time", "macros"] }
tracing = "0.1.32"
uuid = { version = "1.0.0", features = ["v4"] }
tracing = "0.1.34"


[dev-dependencies]
structopt = "0.3.26"
Expand Down
70 changes: 29 additions & 41 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@ use std::os::windows::io::{FromRawSocket, IntoRawSocket};

use std::{
collections::HashMap,
convert::TryInto,
io,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Instant,
};

use pnet_packet::{icmp, icmpv6, ipv4, Packet};
use rand::random;
use socket2::{Domain, Protocol, Socket, Type};
use tokio::{
net::UdpSocket,
sync::{broadcast, mpsc, Mutex},
task,
sync::{mpsc, Mutex},
task::{self, JoinHandle},
};
use tracing::warn;
use uuid::Uuid;

use crate::{config::Config, Pinger, ICMP};

Expand Down Expand Up @@ -81,21 +82,24 @@ impl AsyncSocket {
}
}

pub(crate) type UniqueId = [u8; 16];
pub(crate) type ClientMapping = Arc<Mutex<HashMap<UniqueId, mpsc::Sender<Message>>>>;
///
/// If you want to pass the `Client` in the task, please wrap it with `Arc`: `Arc<Client>`.
/// and can realize the simultaneous ping of multiple addresses when only one `socket` is created.
///
#[derive(Clone)]
pub struct Client {
socket: AsyncSocket,
mapping: Arc<Mutex<HashMap<Uuid, mpsc::Sender<Message>>>>,
shutdown_tx: broadcast::Sender<()>,
mapping: ClientMapping,
recv: Arc<JoinHandle<()>>,
}

impl Drop for Client {
fn drop(&mut self) {
if self.shutdown_tx.send(()).is_err() {
warn!("Client shutdown error.");
// The client may pass through multiple tasks, so need to judge whether the number of references is 1.
if Arc::strong_count(&self.recv) <= 1 {
self.recv.abort();
}
}
}
Expand All @@ -106,63 +110,47 @@ impl Client {
pub fn new(config: &Config) -> io::Result<Self> {
let socket = AsyncSocket::new(config)?;
let mapping = Arc::new(Mutex::new(HashMap::new()));
let (shutdown_tx, _) = broadcast::channel(1);
task::spawn(recv_task(
socket.clone(),
mapping.clone(),
shutdown_tx.subscribe(),
));
let recv = task::spawn(recv_task(socket.clone(), mapping.clone()));

Ok(Self {
socket,
mapping,
shutdown_tx,
recv: Arc::new(recv),
})
}

/// Create a `Pinger` instance, you can make special configuration for this instance. Such as `timeout`, `size` etc.
pub async fn pinger(&self, host: IpAddr) -> Pinger {
let (tx, rx) = mpsc::channel(10);
let key = Uuid::new_v4();
let key: UniqueId = random();
{
self.mapping.lock().await.insert(key, tx);
}
Pinger::new(host, self.socket.clone(), rx, key, self.mapping.clone())
}
}

async fn recv_task(
socket: AsyncSocket,
mapping: Arc<Mutex<HashMap<Uuid, mpsc::Sender<Message>>>>,
mut shutdown_rx: broadcast::Receiver<()>,
) {
async fn recv_task(socket: AsyncSocket, mapping: ClientMapping) {
let mut buf = [0; 2048];

loop {
tokio::select! {
response = socket.recv_from(&mut buf) => {
if let Ok((sz, addr)) = response {
let datas = buf[0..sz].to_vec();
if let Some(uuid) = gen_uuid_with_payload(addr.ip(), datas.as_slice()) {
let instant = Instant::now();
let mut w = mapping.lock().await;
if let Some(tx) = (*w).get(&uuid) {
if tx.send(Message::new(instant, datas)).await.is_err() {
warn!("Pinger({}) already closed.", addr);
(*w).remove(&uuid);
}
}
if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
let datas = buf[0..sz].to_vec();
if let Some(uid) = gen_uid_with_payload(addr.ip(), datas.as_slice()) {
let instant = Instant::now();
let mut w = mapping.lock().await;
if let Some(tx) = (*w).get(&uid) {
if tx.send(Message::new(instant, datas)).await.is_err() {
warn!("Pinger({}) already closed.", addr);
(*w).remove(&uid);
}
}
}
_ = shutdown_rx.recv() => {
break;
}
}
}
}

fn gen_uuid_with_payload(addr: IpAddr, datas: &[u8]) -> Option<Uuid> {
fn gen_uid_with_payload(addr: IpAddr, datas: &[u8]) -> Option<UniqueId> {
match addr {
IpAddr::V4(_) => {
if let Some(ip_packet) = ipv4::Ipv4Packet::new(datas) {
Expand All @@ -173,8 +161,8 @@ fn gen_uuid_with_payload(addr: IpAddr, datas: &[u8]) -> Option<Uuid> {
return None;
}

let uuid = &payload[4..20];
return Uuid::from_slice(uuid).ok();
let uid = &payload[4..20];
return uid.try_into().ok();
}
}
}
Expand All @@ -186,8 +174,8 @@ fn gen_uuid_with_payload(addr: IpAddr, datas: &[u8]) -> Option<Uuid> {
return None;
}

let uuid = &payload[4..20];
return Uuid::from_slice(uuid).ok();
let uid = &payload[4..20];
return uid.try_into().ok();
}
}
}
Expand Down
35 changes: 12 additions & 23 deletions src/ping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ use std::{
use parking_lot::Mutex;
use rand::random;
use tokio::{
sync::{broadcast, mpsc, Mutex as TokioMutex},
sync::{broadcast, mpsc},
task,
time::timeout,
};
use tracing::warn;
use uuid::Uuid;

use crate::client::{AsyncSocket, Message};
use crate::client::{AsyncSocket, ClientMapping, Message, UniqueId};
use crate::error::{Result, SurgeError};
use crate::icmp::{icmpv4, icmpv6, IcmpPacket};

Expand Down Expand Up @@ -51,7 +50,7 @@ pub struct Pinger {
socket: AsyncSocket,
rx: mpsc::Receiver<Message>,
cache: Cache,
key: Uuid,
key: UniqueId,
clear_tx: broadcast::Sender<()>,
}

Expand All @@ -68,8 +67,8 @@ impl Pinger {
host: IpAddr,
socket: AsyncSocket,
rx: mpsc::Receiver<Message>,
key: Uuid,
mapping: Arc<TokioMutex<HashMap<Uuid, mpsc::Sender<Message>>>>,
key: UniqueId,
mapping: ClientMapping,
) -> Pinger {
let (clear_tx, _) = broadcast::channel(1);
task::spawn(clear_mapping_key(key, mapping, clear_tx.subscribe()));
Expand Down Expand Up @@ -131,18 +130,12 @@ impl Pinger {
pub async fn ping(&mut self, seq_cnt: u16) -> Result<(IcmpPacket, Duration)> {
let sender = self.socket.clone();
let mut packet = match self.destination {
IpAddr::V4(_) => icmpv4::make_icmpv4_echo_packet(
self.ident,
seq_cnt,
self.size,
self.key.as_bytes(),
)?,
IpAddr::V6(_) => icmpv6::make_icmpv6_echo_packet(
self.ident,
seq_cnt,
self.size,
self.key.as_bytes(),
)?,
IpAddr::V4(_) => {
icmpv4::make_icmpv4_echo_packet(self.ident, seq_cnt, self.size, &self.key)?
}
IpAddr::V6(_) => {
icmpv6::make_icmpv6_echo_packet(self.ident, seq_cnt, self.size, &self.key)?
}
};
// let mut packet = EchoRequest::new(self.host, self.ident, seq_cnt, self.size).encode()?;
let sock_addr = SocketAddr::new(self.destination, 0);
Expand All @@ -164,11 +157,7 @@ impl Pinger {
}
}

async fn clear_mapping_key(
key: Uuid,
mapping: Arc<TokioMutex<HashMap<Uuid, mpsc::Sender<Message>>>>,
mut rx: broadcast::Receiver<()>,
) {
async fn clear_mapping_key(key: UniqueId, mapping: ClientMapping, mut rx: broadcast::Receiver<()>) {
if rx.recv().await.is_ok() {
mapping.lock().await.remove(&key);
}
Expand Down

0 comments on commit d4eb57b

Please sign in to comment.