Skip to content

Commit

Permalink
✨ hosts-file supports glob matching
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish committed Jul 19, 2024
1 parent 2c72b19 commit c867110
Show file tree
Hide file tree
Showing 20 changed files with 205 additions and 34 deletions.
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ tracing-subscriber = { version = "0.3", features = [
# tracing-appender = "0.2"

# hickory dns
hickory-proto = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.1", version = "0.25.0-alpha.1", features = ["serde-config"]}
hickory-resolver = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.1", version = "0.25.0-alpha.1", features = [
hickory-proto = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.2", version = "0.25.0-alpha.1", features = ["serde-config"]}
hickory-resolver = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.2", version = "0.25.0-alpha.1", features = [
"serde-config",
"system-config",
] }
hickory-server = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.1", version = "0.25.0-alpha.1", features = ["resolver"], optional = true }
hickory-server = { git = "https://github.com/mokeyish/hickory-dns.git", rev = "0.25.0-smartdns.2", version = "0.25.0-alpha.1", features = ["resolver"], optional = true }

# ssl
webpki-roots = "0.25.2"
Expand All @@ -155,6 +155,7 @@ hostname = "0.3"
byte-unit = { version = "5.0.3", features = ["serde"]}
ipnet = "2.7"
which = { version = "6.0.1", optional = true }
glob = "0.3.1"

# process
sysinfo = "0.29"
Expand Down
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ pub struct Config {
/// whether resolv local hostname to ip address
pub resolv_hostname: Option<bool>,

pub hosts_file: Option<PathBuf>,
pub hosts_file: Option<glob::Pattern>,

pub expand_ptr_from_address: Option<bool>,

Expand Down
58 changes: 58 additions & 0 deletions src/config/parser/glob_pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use glob::Pattern;

use super::*;

impl NomParser for Pattern {
fn parse(input: &str) -> IResult<&str, Self> {
let delimited_path = delimited(char('"'), is_not("\""), char('"'));
let unix_path = recognize(tuple((
opt(char('/')),
separated_list1(char('/'), escaped(is_not("\n \t\\"), '\\', one_of(r#" \"#))),
opt(char('/')),
)));
let windows_path = recognize(tuple((
opt(pair(alpha1, tag(":\\"))),
separated_list1(char('\\'), is_not("\\")),
opt(char('\\')),
)));
map_res(
alt((delimited_path, unix_path, windows_path)),
FromStr::from_str,
)(input)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parse() {
assert_eq!(Pattern::parse("a*"), Ok(("", "a*".parse().unwrap())));
assert_eq!(Pattern::parse("/"), Ok(("", "/".parse().unwrap())));
assert_eq!(
Pattern::parse("a/b😁/c"),
Ok(("", "a/b😁/c".parse().unwrap()))
);
assert_eq!(
Pattern::parse("a/ b/c"),
Ok((" b/c", "a/".parse().unwrap()))
);
assert_eq!(
Pattern::parse("/a/b/c"),
Ok(("", "/a/b/c".parse().unwrap()))
);
assert_eq!(
Pattern::parse("/a/b/c/"),
Ok(("", "/a/b/c/".parse().unwrap()))
);
assert_eq!(
Pattern::parse("a/b/c*/"),
Ok(("", "a/b/c*/".parse().unwrap()))
);
assert_eq!(
Pattern::parse("**/*.rs"),
Ok(("", "**/*.rs".parse().unwrap()))
);
}
}
13 changes: 10 additions & 3 deletions src/config/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod domain_rule;
mod domain_set;
mod file_mode;
mod forward_rule;
mod glob_pattern;
mod ipnet;
mod listener;
mod log_level;
Expand Down Expand Up @@ -96,10 +97,11 @@ pub enum OneConfig {
DualstackIpSelection(bool),
DualstackIpSelectionThreshold(u16),
EdnsClientSubnet(IpNet),
ExpandPtrFromAddress(bool),
ForceAAAASOA(bool),
ForceQtypeSoa(RecordType),
ForwardRule(ForwardRule),
HostsFile(PathBuf),
HostsFile(glob::Pattern),
IgnoreIp(IpNet),
Listener(ListenerConfig),
LocalTtl(u64),
Expand Down Expand Up @@ -182,6 +184,10 @@ pub fn parse_config(input: &str) -> IResult<&str, OneConfig> {
parse_item("edns-client-subnet"),
OneConfig::EdnsClientSubnet,
),
map(
parse_item("expand-ptr-from-address"),
OneConfig::ExpandPtrFromAddress,
),
map(parse_item("force-AAAA-SOA"), OneConfig::ForceAAAASOA),
map(parse_item("force-qtype-soa"), OneConfig::ForceQtypeSoa),
map(parse_item("response"), OneConfig::ResponseMode),
Expand All @@ -199,17 +205,18 @@ pub fn parse_config(input: &str) -> IResult<&str, OneConfig> {
map(parse_item("log-num"), OneConfig::LogNum),
map(parse_item("log-size"), OneConfig::LogSize),
map(parse_item("max-reply-ip-num"), OneConfig::MaxReplyIpNum),
map(parse_item("mdns-lookup"), OneConfig::MdnsLookup),
map(parse_item("nameserver"), OneConfig::ForwardRule),
));

let group3 = alt((
map(parse_item("mdns-lookup"), OneConfig::MdnsLookup),
map(parse_item("nameserver"), OneConfig::ForwardRule),
map(parse_item("proxy-server"), OneConfig::ProxyConfig),
map(parse_item("rr-ttl-reply-max"), OneConfig::RrTtlReplyMax),
map(parse_item("rr-ttl-min"), OneConfig::RrTtlMin),
map(parse_item("rr-ttl-max"), OneConfig::RrTtlMax),
map(parse_item("rr-ttl"), OneConfig::RrTtl),
map(parse_item("resolv-file"), OneConfig::ResolvFile),
map(parse_item("resolv-hostanme"), OneConfig::ResolvHostname),
map(parse_item("response-mode"), OneConfig::ResponseMode),
map(parse_item("server-name"), OneConfig::ServerName),
map(parse_item("speed-check-mode"), OneConfig::SpeedMode),
Expand Down
11 changes: 6 additions & 5 deletions src/dns_conf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ impl RuntimeConfig {

/// hosts file path
#[inline]
pub fn hosts_file(&self) -> Option<&Path> {
self.hosts_file.as_deref()
pub fn hosts_file(&self) -> Option<&glob::Pattern> {
self.hosts_file.as_ref()
}

/// Whether to expand the address record corresponding to PTR record
Expand Down Expand Up @@ -781,6 +781,7 @@ impl RuntimeConfigBuilder {
CacheFile(v) => self.cache.file = Some(v),
CachePersist(v) => self.cache.persist = Some(v),
CName(v) => self.cnames.push(v),
ExpandPtrFromAddress(v) => self.expand_ptr_from_address = Some(v),
NftSet(config) => self.nftsets.push(config),
Server(server) => self.nameservers.push(server),
ResponseMode(mode) => self.response_mode = Some(mode),
Expand Down Expand Up @@ -1122,7 +1123,7 @@ mod tests {
#[test]
fn test_config_domain_rules_without_args() {
let mut cfg = RuntimeConfig::builder();
cfg.config("domain-set -name domain-forwarding-list -file tests/test_confs/block-list.txt");
cfg.config("domain-set -name domain-forwarding-list -file tests/test_data/block-list.txt");
cfg.config("domain-rules /domain-set:domain-forwarding-list/");
assert!(cfg.address_rules.last().is_none());
}
Expand Down Expand Up @@ -1441,7 +1442,7 @@ mod tests {

#[test]
fn test_parse_load_config_file_b() {
let cfg = RuntimeConfig::load_from_file("tests/test_confs/b_main.conf");
let cfg = RuntimeConfig::load_from_file("tests/test_data/b_main.conf");

assert_eq!(cfg.server_name, "SmartDNS123".parse().ok());
assert_eq!(
Expand All @@ -1465,7 +1466,7 @@ mod tests {
#[test]
#[cfg(failed_tests)]
fn test_domain_set() {
let cfg = RuntimeConfig::load_from_file("tests/test_confs/b_main.conf");
let cfg = RuntimeConfig::load_from_file("tests/test_data/b_main.conf");

assert!(!cfg.domain_sets.is_empty());

Expand Down
125 changes: 111 additions & 14 deletions src/dns_mw_hosts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use std::time::{Duration, Instant};
use crate::libdns::proto::op::Query;
use tokio::sync::RwLock;

use crate::dns::*;
use crate::libdns::resolver::Hosts;
use crate::middleware::*;
use crate::{dns::*, log};

pub struct DnsHostsMiddleware(RwLock<Option<(Instant, Arc<Hosts>)>>);

Expand Down Expand Up @@ -41,19 +41,8 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for DnsHostsMiddl
Some(v) => v,
None => {
let hosts = match ctx.cfg().hosts_file() {
Some(file) => {
if file.exists() {
std::fs::OpenOptions::new()
.read(true)
.open(file)
.map(|f| Hosts::default().read_hosts_conf(f))
.unwrap_or_else(Err)
.unwrap_or_default()
} else {
Hosts::default()
}
}
None => Hosts::new(),
Some(pattern) => read_hosts(pattern.as_str()),
None => Hosts::new(), // read from system hosts file
};
let hosts = Arc::new(hosts);
*self.0.write().await = Some((Instant::now(), hosts.clone()));
Expand All @@ -77,3 +66,111 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for DnsHostsMiddl
next.run(ctx, req).await
}
}

fn read_hosts(pattern: &str) -> Hosts {
let mut hosts = Hosts::default();
match glob::glob(pattern) {
Ok(paths) => {
for entry in paths {
let path = match entry {
Ok(path) => {
if !path.is_file() {
continue;
}
path
}
Err(err) => {
log::error!("{}", err);
continue;
}
};

let file = match std::fs::OpenOptions::new().read(true).open(path) {
Ok(file) => file,
Err(err) => {
log::error!("{}", err);
continue;
}
};

if let Err(err) = hosts.read_hosts_conf(file) {
log::error!("{}", err);
}
}
}
Err(err) => {
log::error!("{}", err);
}
}
hosts
}

#[cfg(test)]
mod tests {
use std::{net::IpAddr, str::FromStr};

use crate::libdns::proto::rr::rdata::PTR;

use super::*;

use crate::{dns_conf::RuntimeConfig, dns_mw::*};

#[tokio::test()]
async fn test_query_ip() -> anyhow::Result<()> {
let cfg = RuntimeConfig::builder()
.with("hosts-file ./tests/test_data/hosts/a*.hosts")
.build();

let mock = DnsMockMiddleware::mock(DnsHostsMiddleware::new()).build(cfg);

let lookup = mock.lookup("hi.a1", RecordType::A).await?;
let ip_addrs = lookup
.records()
.iter()
.flat_map(|r| r.data().ip_addr())
.collect::<Vec<_>>();
assert_eq!(ip_addrs, vec![IpAddr::from_str("1.1.1.1").unwrap()]);

let lookup = mock.lookup("hi.a2", RecordType::A).await?;
let ip_addrs = lookup
.records()
.iter()
.flat_map(|r| r.data().ip_addr())
.collect::<Vec<_>>();
assert_eq!(ip_addrs, vec![IpAddr::from_str("2.2.2.2").unwrap()]);

Ok(())
}

#[tokio::test()]
async fn test_query_ptr() -> anyhow::Result<()> {
let cfg = RuntimeConfig::builder()
.with("hosts-file ./tests/test_data/hosts/a*.hosts")
.with("expand-ptr-from-address yes")
.build();

let mock = DnsMockMiddleware::mock(DnsHostsMiddleware::new()).build(cfg);

let lookup = mock
.lookup("1.1.1.1.in-addr.arpa.", RecordType::PTR)
.await?;
let hostnames = lookup
.records()
.iter()
.flat_map(|r| r.data().as_ptr())
.collect::<Vec<_>>();
assert_eq!(hostnames, vec![&PTR("hi.a1.".parse().unwrap())]);

let lookup = mock
.lookup("2.2.2.2.in-addr.arpa.", RecordType::PTR)
.await?;
let hostnames = lookup
.records()
.iter()
.flat_map(|r| r.data().as_ptr())
.collect::<Vec<_>>();
assert_eq!(hostnames, vec![&PTR("hi.a2.".parse().unwrap())]);

Ok(())
}
}
Loading

0 comments on commit c867110

Please sign in to comment.