diff --git a/Cargo.toml b/Cargo.toml index aff3838..4546131 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ idna = "0.3" lazy_static = "1.4.0" parking_lot = "0.12" regex = "1.7" -serde = { version = "1", features = ["derive"] } +serde = { version = "1", features = ["derive","rc"] } tokio = { version = "1", features = ["macros","signal"] } tera = "1" toml = "0.7" diff --git a/src/config/dns.rs b/src/config/dns.rs index 6db438b..94872e6 100644 --- a/src/config/dns.rs +++ b/src/config/dns.rs @@ -1,6 +1,7 @@ use serde::{Deserialize,Serialize}; use trust_dns_resolver::config::Protocol; +use std::sync::Arc; use std::collections::HashMap; use std::net::SocketAddr; @@ -10,12 +11,12 @@ pub struct DnsConfig { pub allow_forward_lookup: bool, pub allow_reverse_lookup: bool, pub hidden_suffixes: Vec, - pub resolver: HashMap, + pub resolver: HashMap,DnsResolverConfig>, pub enable_system_resolver: bool, - pub system_resolver_name: String, + pub system_resolver_name: Arc, pub system_resolver_weight: i32, - pub system_resolver_id: String, + pub system_resolver_id: Arc, } #[derive(Deserialize, Serialize, Clone)] @@ -30,16 +31,16 @@ pub enum DnsProtocol { #[derive(Deserialize, Serialize, Clone)] pub struct DnsResolverConfig { - pub display_name: String, + pub display_name: Arc, #[serde(default)] - pub info_url: Option, + pub info_url: Option>, #[serde(default)] - pub aliases: Vec, + pub aliases: Vec>, #[serde(default="zero")] pub weight: i32, pub servers: Vec, pub protocol: DnsProtocol, - pub tls_dns_name: Option, + pub tls_dns_name: Option>, #[serde(skip_serializing)] //Don't leak our bind address to the outside pub bind_address: Option, #[serde(default="default_true")] @@ -63,9 +64,9 @@ impl Default for DnsConfig { resolver: Default::default(), enable_system_resolver: true, - system_resolver_name: "System".to_string(), + system_resolver_name: "System".into(), system_resolver_weight: 1000, - system_resolver_id: "system".to_string(), + system_resolver_id: "system".into(), } } } @@ -91,7 +92,7 @@ impl DnsResolverConfig { resolver.add_name_server(trust_dns_resolver::config::NameServerConfig{ socket_addr: *server, protocol: self.protocol.clone().into(), - tls_dns_name: self.tls_dns_name.clone(), + tls_dns_name: self.tls_dns_name.clone().map(|s| s.to_string()), trust_nx_responses: self.trust_nx_responses, tls_config: None, bind_addr: self.bind_address, diff --git a/src/main.rs b/src/main.rs index 5163fb4..37e557c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -76,26 +76,26 @@ pub struct IpResult { asn: Option, location: Option, ip_info: AddressInfo, - used_dns_resolver: Option, + used_dns_resolver: Option>, } // We need this one to hide the partial lookup field when irelevant pub fn not(b: &bool) -> bool { !b } -#[derive(Serialize, Default, Clone)] +#[derive(Serialize, Clone)] pub struct DigResult { records: simple_dns::DnsLookupResult, #[serde(skip_serializing_if = "IdnaName::was_ascii")] idn: IdnaName, #[serde(skip_serializing_if = "not")] partial_lookup: bool, - used_dns_resolver: String, + used_dns_resolver: Arc, } struct ServiceSharedState { templating_engine: templating_engine::Engine, - dns_resolvers: HashMap, - dns_resolver_aliases: HashMap, + dns_resolvers: HashMap,TokioAsyncResolver>, + dns_resolver_aliases: HashMap,Arc>, asn_db: geoip::MMDBCarrier, location_db: geoip::MMDBCarrier, config: config::EchoIpServiceConfig, @@ -105,7 +105,7 @@ struct ServiceSharedState { #[derive(Clone)] struct DerivedConfiguration { dns_resolver_selectables: Vec, - default_resolver: String, + default_resolver: Arc, } #[derive(Parser)] @@ -234,8 +234,8 @@ async fn main() { println!("Initalizing dns resolvers ..."); let mut dns_resolver_selectables = Vec::::new(); - let mut dns_resolver_map: HashMap = HashMap::new(); - let mut dns_resolver_aliases: HashMap = HashMap::new(); + let mut dns_resolver_map: HashMap,TokioAsyncResolver> = HashMap::new(); + let mut dns_resolver_aliases: HashMap,Arc> = HashMap::new(); if config.dns.enable_system_resolver { println!("Initalizing System resolver ..."); @@ -290,7 +290,7 @@ async fn main() { dns_resolver_selectables.sort_by(|a,b| b.weight.cmp(&a.weight)); let default_resolver = dns_resolver_selectables.get(0) .map(|s| s.id.clone() ) - .unwrap_or("none".to_string()); + .unwrap_or("none".into()); let derived_config = DerivedConfiguration { dns_resolver_selectables: dns_resolver_selectables, default_resolver: default_resolver, @@ -364,10 +364,10 @@ async fn settings_query_middleware( let mut dns_resolver_id = derived_config.default_resolver; if let Some(resolver_id) = query.dns { - dns_resolver_id = resolver_id; + dns_resolver_id = resolver_id.into(); } else if let Some(cookie_header) = cookie_header { if let Some(resolver_id) = cookie_header.0.get("dns_resolver") { - dns_resolver_id = resolver_id.to_string(); + dns_resolver_id = resolver_id.into(); } } @@ -448,7 +448,7 @@ async fn handle_default_route( } } - let result = get_ip_result(&address, &settings.lang, &"default".to_string(), &state).await; + let result = get_ip_result(&address, &settings.lang, &"default".into(), &state).await; let user_agent: Option = match user_agent_header { Some(TypedHeader(user_agent)) => Some(user_agent.to_string()), @@ -493,11 +493,11 @@ async fn handle_search_request( } if let Some(via_cap) = VIA_REGEX.captures(&search_query) { - if let Some(via) = via_cap.get(1).map(|c| c.as_str().to_string()) { + if let Some(via) = via_cap.get(1) { let state = Arc::clone(&arc_state); - if state.dns_resolvers.contains_key(&via) { - settings.dns_resolver_id = via; - } else if let Some(alias) = state.dns_resolver_aliases.get(&via) { + if state.dns_resolvers.contains_key(via.as_str()) { + settings.dns_resolver_id = via.as_str().into(); + } else if let Some(alias) = state.dns_resolver_aliases.get(via.as_str()) { settings.dns_resolver_id = alias.clone(); } } @@ -537,7 +537,7 @@ async fn handle_dns_resolver_route_with_path( extract::Path(query): extract::Path, ) -> Response { let state = Arc::clone(&arc_state); - if let Some(resolver) = state.config.dns.resolver.get(&query) { + if let Some(resolver) = state.config.dns.resolver.get(query.as_str()) { state.templating_engine.render_view( &settings, &View::DnsResolver{ config: resolver.clone() }, @@ -584,7 +584,7 @@ async fn handle_ip_request( async fn get_ip_result( address: &IpAddr, lang: &String, - dns_resolver_name: &String, + dns_resolver_name: &Arc, state: &ServiceSharedState, ) -> IpResult { @@ -605,7 +605,7 @@ async fn get_ip_result( // do reverse lookup let mut hostname: Option = None; - let mut used_dns_resolver: Option = None; + let mut used_dns_resolver: Option> = None; if state.config.dns.allow_reverse_lookup { if let Some(dns_resolver) = &state.dns_resolvers.get(dns_resolver_name) { hostname = simple_dns::reverse_lookup(&dns_resolver, &address).await; @@ -673,13 +673,13 @@ async fn handle_dig_request( async fn get_dig_result( dig_query: &String, - dns_resolver_name: &String, + dns_resolver_name: &Arc, state: &ServiceSharedState, do_full_lookup: bool, ) -> DigResult { let name = &dig_query.trim().trim_end_matches(".").to_string(); let idna_name = IdnaName::from_string(&name); - if let Some(dns_resolver) = &state.dns_resolvers.get(dns_resolver_name) { + if let Some(dns_resolver) = state.dns_resolvers.get(dns_resolver_name) { if let Ok(domain_name) = Name::from_str_relaxed(name.to_owned()+".") { if match_domain_hidden_list(&name, &state.config.dns.hidden_suffixes) { // Try to hide the fact that we didn't do dns resolution at all @@ -707,21 +707,27 @@ async fn get_dig_result( } } } else { + // Invalid domain name return DigResult { records: DnsLookupResult{ invalid_name: true, .. Default::default() }, - .. Default::default() + idn: idna_name, + partial_lookup: !do_full_lookup, + used_dns_resolver: dns_resolver_name.clone(), } } } else { + // Unknown resolver name return DigResult { records: DnsLookupResult{ unkown_resolver: true, .. Default::default() }, - .. Default::default() + idn: idna_name, + partial_lookup: !do_full_lookup, + used_dns_resolver: "unkown_resolver".into(), } } } diff --git a/src/settings.rs b/src/settings.rs index 39ee983..b5ebbfb 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,5 +1,7 @@ use serde::{Deserialize,Serialize}; +use std::sync::Arc; + /* Response format */ #[derive(Deserialize, Serialize, Clone, Copy)] @@ -39,13 +41,13 @@ pub struct QuerySettings { pub format: ResponseFormat, pub lang: String, pub available_dns_resolvers: Vec, - pub dns_resolver_id: String, + pub dns_resolver_id: Arc, } #[derive(Deserialize, Serialize, Clone)] pub struct Selectable { - pub id: String, - pub name: String, + pub id: Arc, + pub name: Arc, pub weight: i32, } diff --git a/src/templating_engine.rs b/src/templating_engine.rs index faf9a82..4b97924 100644 --- a/src/templating_engine.rs +++ b/src/templating_engine.rs @@ -120,7 +120,7 @@ impl Engine { View::NotFound => *response.status_mut() = StatusCode::NOT_FOUND, _ => {}, } - let cookie = Cookie::build("dns_resolver",settings.dns_resolver_id.clone()) + let cookie = Cookie::build("dns_resolver",settings.dns_resolver_id.to_string()) .path("/") .same_site(cookie::SameSite::Strict) .finish();