Switch a lot of Strings to Arc<str>

This commit is contained in:
Slatian 2023-10-08 09:12:06 +02:00
parent 223abdd804
commit 5c74de5685
5 changed files with 47 additions and 38 deletions

View File

@ -16,7 +16,7 @@ idna = "0.3"
lazy_static = "1.4.0"
parking_lot = "0.12"
regex = "1.7"
serde = { version = "1", features = ["derive"] }
serde = { version = "1", features = ["derive","rc"] }
tokio = { version = "1", features = ["macros","signal"] }
tera = "1"
toml = "0.7"

View File

@ -1,6 +1,7 @@
use serde::{Deserialize,Serialize};
use trust_dns_resolver::config::Protocol;
use std::sync::Arc;
use std::collections::HashMap;
use std::net::SocketAddr;
@ -10,12 +11,12 @@ pub struct DnsConfig {
pub allow_forward_lookup: bool,
pub allow_reverse_lookup: bool,
pub hidden_suffixes: Vec<String>,
pub resolver: HashMap<String,DnsResolverConfig>,
pub resolver: HashMap<Arc<str>,DnsResolverConfig>,
pub enable_system_resolver: bool,
pub system_resolver_name: String,
pub system_resolver_name: Arc<str>,
pub system_resolver_weight: i32,
pub system_resolver_id: String,
pub system_resolver_id: Arc<str>,
}
#[derive(Deserialize, Serialize, Clone)]
@ -30,16 +31,16 @@ pub enum DnsProtocol {
#[derive(Deserialize, Serialize, Clone)]
pub struct DnsResolverConfig {
pub display_name: String,
pub display_name: Arc<str>,
#[serde(default)]
pub info_url: Option<String>,
pub info_url: Option<Arc<str>>,
#[serde(default)]
pub aliases: Vec<String>,
pub aliases: Vec<Arc<str>>,
#[serde(default="zero")]
pub weight: i32,
pub servers: Vec<SocketAddr>,
pub protocol: DnsProtocol,
pub tls_dns_name: Option<String>,
pub tls_dns_name: Option<Arc<str>>,
#[serde(skip_serializing)] //Don't leak our bind address to the outside
pub bind_address: Option<SocketAddr>,
#[serde(default="default_true")]
@ -63,9 +64,9 @@ impl Default for DnsConfig {
resolver: Default::default(),
enable_system_resolver: true,
system_resolver_name: "System".to_string(),
system_resolver_name: "System".into(),
system_resolver_weight: 1000,
system_resolver_id: "system".to_string(),
system_resolver_id: "system".into(),
}
}
}
@ -91,7 +92,7 @@ impl DnsResolverConfig {
resolver.add_name_server(trust_dns_resolver::config::NameServerConfig{
socket_addr: *server,
protocol: self.protocol.clone().into(),
tls_dns_name: self.tls_dns_name.clone(),
tls_dns_name: self.tls_dns_name.clone().map(|s| s.to_string()),
trust_nx_responses: self.trust_nx_responses,
tls_config: None,
bind_addr: self.bind_address,

View File

@ -76,26 +76,26 @@ pub struct IpResult {
asn: Option<AsnResult>,
location: Option<LocationResult>,
ip_info: AddressInfo,
used_dns_resolver: Option<String>,
used_dns_resolver: Option<Arc<str>>,
}
// We need this one to hide the partial lookup field when irelevant
pub fn not(b: &bool) -> bool { !b }
#[derive(Serialize, Default, Clone)]
#[derive(Serialize, Clone)]
pub struct DigResult {
records: simple_dns::DnsLookupResult,
#[serde(skip_serializing_if = "IdnaName::was_ascii")]
idn: IdnaName,
#[serde(skip_serializing_if = "not")]
partial_lookup: bool,
used_dns_resolver: String,
used_dns_resolver: Arc<str>,
}
struct ServiceSharedState {
templating_engine: templating_engine::Engine,
dns_resolvers: HashMap<String,TokioAsyncResolver>,
dns_resolver_aliases: HashMap<String,String>,
dns_resolvers: HashMap<Arc<str>,TokioAsyncResolver>,
dns_resolver_aliases: HashMap<Arc<str>,Arc<str>>,
asn_db: geoip::MMDBCarrier,
location_db: geoip::MMDBCarrier,
config: config::EchoIpServiceConfig,
@ -105,7 +105,7 @@ struct ServiceSharedState {
#[derive(Clone)]
struct DerivedConfiguration {
dns_resolver_selectables: Vec<Selectable>,
default_resolver: String,
default_resolver: Arc<str>,
}
#[derive(Parser)]
@ -234,8 +234,8 @@ async fn main() {
println!("Initalizing dns resolvers ...");
let mut dns_resolver_selectables = Vec::<Selectable>::new();
let mut dns_resolver_map: HashMap<String,TokioAsyncResolver> = HashMap::new();
let mut dns_resolver_aliases: HashMap<String,String> = HashMap::new();
let mut dns_resolver_map: HashMap<Arc<str>,TokioAsyncResolver> = HashMap::new();
let mut dns_resolver_aliases: HashMap<Arc<str>,Arc<str>> = HashMap::new();
if config.dns.enable_system_resolver {
println!("Initalizing System resolver ...");
@ -290,7 +290,7 @@ async fn main() {
dns_resolver_selectables.sort_by(|a,b| b.weight.cmp(&a.weight));
let default_resolver = dns_resolver_selectables.get(0)
.map(|s| s.id.clone() )
.unwrap_or("none".to_string());
.unwrap_or("none".into());
let derived_config = DerivedConfiguration {
dns_resolver_selectables: dns_resolver_selectables,
default_resolver: default_resolver,
@ -364,10 +364,10 @@ async fn settings_query_middleware<B>(
let mut dns_resolver_id = derived_config.default_resolver;
if let Some(resolver_id) = query.dns {
dns_resolver_id = resolver_id;
dns_resolver_id = resolver_id.into();
} else if let Some(cookie_header) = cookie_header {
if let Some(resolver_id) = cookie_header.0.get("dns_resolver") {
dns_resolver_id = resolver_id.to_string();
dns_resolver_id = resolver_id.into();
}
}
@ -448,7 +448,7 @@ async fn handle_default_route(
}
}
let result = get_ip_result(&address, &settings.lang, &"default".to_string(), &state).await;
let result = get_ip_result(&address, &settings.lang, &"default".into(), &state).await;
let user_agent: Option<String> = match user_agent_header {
Some(TypedHeader(user_agent)) => Some(user_agent.to_string()),
@ -493,11 +493,11 @@ async fn handle_search_request(
}
if let Some(via_cap) = VIA_REGEX.captures(&search_query) {
if let Some(via) = via_cap.get(1).map(|c| c.as_str().to_string()) {
if let Some(via) = via_cap.get(1) {
let state = Arc::clone(&arc_state);
if state.dns_resolvers.contains_key(&via) {
settings.dns_resolver_id = via;
} else if let Some(alias) = state.dns_resolver_aliases.get(&via) {
if state.dns_resolvers.contains_key(via.as_str()) {
settings.dns_resolver_id = via.as_str().into();
} else if let Some(alias) = state.dns_resolver_aliases.get(via.as_str()) {
settings.dns_resolver_id = alias.clone();
}
}
@ -537,7 +537,7 @@ async fn handle_dns_resolver_route_with_path(
extract::Path(query): extract::Path<String>,
) -> Response {
let state = Arc::clone(&arc_state);
if let Some(resolver) = state.config.dns.resolver.get(&query) {
if let Some(resolver) = state.config.dns.resolver.get(query.as_str()) {
state.templating_engine.render_view(
&settings,
&View::DnsResolver{ config: resolver.clone() },
@ -584,7 +584,7 @@ async fn handle_ip_request(
async fn get_ip_result(
address: &IpAddr,
lang: &String,
dns_resolver_name: &String,
dns_resolver_name: &Arc<str>,
state: &ServiceSharedState,
) -> IpResult {
@ -605,7 +605,7 @@ async fn get_ip_result(
// do reverse lookup
let mut hostname: Option<String> = None;
let mut used_dns_resolver: Option<String> = None;
let mut used_dns_resolver: Option<Arc<str>> = None;
if state.config.dns.allow_reverse_lookup {
if let Some(dns_resolver) = &state.dns_resolvers.get(dns_resolver_name) {
hostname = simple_dns::reverse_lookup(&dns_resolver, &address).await;
@ -673,13 +673,13 @@ async fn handle_dig_request(
async fn get_dig_result(
dig_query: &String,
dns_resolver_name: &String,
dns_resolver_name: &Arc<str>,
state: &ServiceSharedState,
do_full_lookup: bool,
) -> DigResult {
let name = &dig_query.trim().trim_end_matches(".").to_string();
let idna_name = IdnaName::from_string(&name);
if let Some(dns_resolver) = &state.dns_resolvers.get(dns_resolver_name) {
if let Some(dns_resolver) = state.dns_resolvers.get(dns_resolver_name) {
if let Ok(domain_name) = Name::from_str_relaxed(name.to_owned()+".") {
if match_domain_hidden_list(&name, &state.config.dns.hidden_suffixes) {
// Try to hide the fact that we didn't do dns resolution at all
@ -707,21 +707,27 @@ async fn get_dig_result(
}
}
} else {
// Invalid domain name
return DigResult {
records: DnsLookupResult{
invalid_name: true,
.. Default::default()
},
.. Default::default()
idn: idna_name,
partial_lookup: !do_full_lookup,
used_dns_resolver: dns_resolver_name.clone(),
}
}
} else {
// Unknown resolver name
return DigResult {
records: DnsLookupResult{
unkown_resolver: true,
.. Default::default()
},
.. Default::default()
idn: idna_name,
partial_lookup: !do_full_lookup,
used_dns_resolver: "unkown_resolver".into(),
}
}
}

View File

@ -1,5 +1,7 @@
use serde::{Deserialize,Serialize};
use std::sync::Arc;
/* Response format */
#[derive(Deserialize, Serialize, Clone, Copy)]
@ -39,13 +41,13 @@ pub struct QuerySettings {
pub format: ResponseFormat,
pub lang: String,
pub available_dns_resolvers: Vec<Selectable>,
pub dns_resolver_id: String,
pub dns_resolver_id: Arc<str>,
}
#[derive(Deserialize, Serialize, Clone)]
pub struct Selectable {
pub id: String,
pub name: String,
pub id: Arc<str>,
pub name: Arc<str>,
pub weight: i32,
}

View File

@ -120,7 +120,7 @@ impl Engine {
View::NotFound => *response.status_mut() = StatusCode::NOT_FOUND,
_ => {},
}
let cookie = Cookie::build("dns_resolver",settings.dns_resolver_id.clone())
let cookie = Cookie::build("dns_resolver",settings.dns_resolver_id.to_string())
.path("/")
.same_site(cookie::SameSite::Strict)
.finish();