Configurable multiple dns resolvers

This commit is contained in:
Slatian
2023-08-04 00:00:21 +02:00
parent cd8c0455dc
commit 104a072fd6
7 changed files with 441 additions and 100 deletions

99
src/config/dns.rs Normal file
View File

@ -0,0 +1,99 @@
use serde::{Deserialize,Serialize};
use trust_dns_resolver::config::Protocol;
use trust_dns_resolver::Name;
use std::collections::HashMap;
use std::net::SocketAddr;
#[derive(Deserialize, Clone)]
pub struct DnsConfig {
pub allow_forward_lookup: bool,
pub allow_reverse_lookup: bool,
pub hidden_suffixes: Vec<String>,
#[serde(default="default_dns_resolver_name")]
pub default_resolver: String,
pub resolver: HashMap<String,DnsResolverConfig>,
}
#[derive(Deserialize, Serialize, Clone)]
pub enum DnsProtocol {
Udp,
Tcp,
Tls,
Https,
Quic,
}
pub fn default_dns_resolver_name() -> String {
"default".to_string()
}
#[derive(Deserialize, Serialize, Clone)]
pub struct DnsResolverConfig {
pub display_name: String,
#[serde(default="zero")]
pub weight: i32,
pub servers: Vec<SocketAddr>,
#[serde(default)]
pub search: Vec<String>,
pub protocol: DnsProtocol,
pub tls_dns_name: Option<String>,
pub bind_address: Option<SocketAddr>,
#[serde(default="default_true")]
pub trust_nx_responses: bool,
}
fn zero() -> i32 {
return 0;
}
fn default_true() -> bool {
return true;
}
impl Default for DnsConfig {
fn default() -> Self {
DnsConfig {
allow_forward_lookup: true,
allow_reverse_lookup: false,
hidden_suffixes: Vec::new(),
default_resolver: "default".to_string(),
resolver: Default::default(),
}
}
}
impl Into<Protocol> for DnsProtocol {
fn into(self) -> Protocol {
match self {
Self::Udp => Protocol::Udp,
Self::Tcp => Protocol::Tcp,
Self::Tls => Protocol::Tls,
Self::Https => Protocol::Https,
Self::Quic => Protocol::Quic,
}
}
}
impl DnsResolverConfig {
pub fn to_trust_resolver_config(&self) -> trust_dns_resolver::config::ResolverConfig {
let mut resolver = trust_dns_resolver::config::ResolverConfig::new();
for server in &self.servers {
resolver.add_name_server(trust_dns_resolver::config::NameServerConfig{
socket_addr: *server,
protocol: self.protocol.clone().into(),
tls_dns_name: self.tls_dns_name.clone(),
trust_nx_responses: self.trust_nx_responses,
tls_config: None,
bind_addr: self.bind_address,
});
}
for search in &self.search {
if let Ok(name) = Name::from_str_relaxed(search) {
resolver.add_search(name);
}
}
return resolver;
}
}

View File

@ -1,8 +1,14 @@
use axum_client_ip::SecureClientIpSource;
use serde::Deserialize;
use std::net::SocketAddr;
use std::num::NonZeroU32;
#[derive(serde::Deserialize, Default, Clone)]
mod dns;
pub use crate::config::dns::{DnsConfig, DnsProtocol, DnsResolverConfig};
#[derive(Deserialize, Default, Clone)]
pub struct EchoIpServiceConfig {
pub server: ServerConfig,
pub dns: DnsConfig,
@ -11,7 +17,7 @@ pub struct EchoIpServiceConfig {
pub ratelimit: RatelimitConfig,
}
#[derive(serde::Deserialize, Clone)]
#[derive(Deserialize, Clone)]
pub struct ServerConfig {
pub listen_on: SocketAddr,
pub ip_header: SecureClientIpSource,
@ -20,33 +26,27 @@ pub struct ServerConfig {
pub static_location: Option<String>,
}
#[derive(serde::Deserialize, Clone)]
pub struct DnsConfig {
pub allow_forward_lookup: bool,
pub allow_reverse_lookup: bool,
pub hidden_suffixes: Vec<String>,
//Future Idea: allow custom resolver
}
#[derive(serde::Deserialize, Clone)]
#[derive(Deserialize, Clone)]
pub struct GeoIpConfig {
pub asn_database: Option<String>,
pub location_database: Option<String>,
}
#[derive(serde::Deserialize, Clone)]
#[derive(Deserialize, Clone)]
pub struct TemplateConfig {
pub template_location: String,
pub extra_config: Option<String>,
pub text_user_agents: Vec<String>,
}
#[derive(serde::Deserialize, Clone)]
#[derive(Deserialize, Clone)]
pub struct RatelimitConfig {
pub per_minute: NonZeroU32,
pub burst: NonZeroU32,
}
impl Default for ServerConfig {
fn default() -> Self {
ServerConfig {
@ -58,16 +58,6 @@ impl Default for ServerConfig {
}
}
impl Default for DnsConfig {
fn default() -> Self {
DnsConfig {
allow_forward_lookup: true,
allow_reverse_lookup: false,
hidden_suffixes: Vec::new(),
}
}
}
impl Default for GeoIpConfig {
fn default() -> Self {
GeoIpConfig {
@ -95,3 +85,4 @@ impl Default for RatelimitConfig {
}
}
}

View File

@ -18,14 +18,11 @@ use axum_client_ip::SecureClientIp;
use clap::Parser;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize,Serialize};
use tera::Tera;
use tower::ServiceBuilder;
use tower_http::services::ServeDir;
use trust_dns_resolver::{
TokioAsyncResolver,
// config::ResolverOpts,
// config::ResolverConfig,
};
use trust_dns_resolver::TokioAsyncResolver;
use tokio::signal::unix::{
signal,
@ -62,28 +59,26 @@ use crate::templating_engine::{
use crate::ipinfo::{AddressCast,AddressInfo,AddressScope};
#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(Deserialize, Serialize, Clone)]
pub struct SettingsQuery {
format: Option<ResponseFormat>,
lang: Option<String>,
dns: Option<String>,
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(Clone, Serialize)]
pub struct QuerySettings {
#[serde(skip)]
template: TemplateSettings,
dns_resolver_id: String,
}
#[derive(Deserialize, Serialize, Clone)]
pub struct SearchQuery {
query: Option<String>,
}
pub fn default_dns_name() -> String {
"default".to_string()
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct ResolverQuery {
#[serde(default="default_dns_name")]
dns: String,
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(Serialize, Clone)]
pub struct IpResult {
address: IpAddr,
hostname: Option<String>,
@ -95,7 +90,7 @@ pub struct IpResult {
// We need this one to hide the partial lookup field when irelevant
pub fn not(b: &bool) -> bool { !b }
#[derive(serde::Deserialize, serde::Serialize, Default, Clone)]
#[derive(Serialize, Default, Clone)]
pub struct DigResult {
records: simple_dns::DnsLookupResult,
#[serde(skip_serializing_if = "IdnaName::was_ascii")]
@ -239,7 +234,7 @@ async fn main() {
// Initalize DNS resolver with os defaults
println!("Initalizing dns resolver ...");
println!("Using System configuration ...");
println!("Initalizing System resolver ...");
let res = TokioAsyncResolver::tokio_from_system_conf();
//let res = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
let dns_resolver = match res {
@ -269,6 +264,15 @@ async fn main() {
let mut dns_resolver_map: HashMap<String,TokioAsyncResolver> = HashMap::new();
for (key, resolver_config) in &config.dns.resolver {
println!("Initalizing {} resolver ...", key);
let resolver = TokioAsyncResolver::tokio(
resolver_config.to_trust_resolver_config(),
Default::default()
).unwrap();
dns_resolver_map.insert(key.clone(), resolver);
}
dns_resolver_map.insert("default".to_string(), dns_resolver);
dns_resolver_map.insert("quad9".to_string(), quad9_resolver);
dns_resolver_map.insert("google".to_string(), google_resolver);
@ -327,7 +331,7 @@ async fn main() {
config.ratelimit.per_minute, config.ratelimit.burst))
.layer(middleware::from_fn(ratelimit::rate_limit_middleware))
.layer(Extension(config))
.layer(middleware::from_fn(format_and_language_middleware))
.layer(middleware::from_fn(settings_query_middleware))
)
;
@ -340,7 +344,7 @@ async fn main() {
}
async fn format_and_language_middleware<B>(
async fn settings_query_middleware<B>(
Query(query): Query<SettingsQuery>,
Extension(config): Extension<config::EchoIpServiceConfig>,
user_agent_header: Option<TypedHeader<headers::UserAgent>>,
@ -361,9 +365,12 @@ async fn format_and_language_middleware<B>(
}
}
// Add the request settings extension
req.extensions_mut().insert(TemplateSettings{
format: format.unwrap_or(ResponseFormat::TextHtml),
lang: query.lang.unwrap_or("en".to_string()),
req.extensions_mut().insert(QuerySettings{
template: TemplateSettings{
format: format.unwrap_or(ResponseFormat::TextHtml),
lang: query.lang.unwrap_or("en".to_string()),
},
dns_resolver_id: query.dns.unwrap_or(config.dns.default_resolver),
});
next.run(req).await
}
@ -404,9 +411,8 @@ async fn user_agent_handler(
async fn handle_default_route(
Query(search_query): Query<SearchQuery>,
Query(resolver_settings): Query<ResolverQuery>,
State(arc_state): State<Arc<ServiceSharedState>>,
Extension(settings): Extension<TemplateSettings>,
Extension(settings): Extension<QuerySettings>,
user_agent_header: Option<TypedHeader<headers::UserAgent>>,
SecureClientIp(address): SecureClientIp
) -> Response {
@ -419,13 +425,12 @@ async fn handle_default_route(
search_query,
false,
settings,
resolver_settings,
state
).await;
}
}
let result = get_ip_result(&address, &settings.lang, &"default".to_string(), &state).await;
let result = get_ip_result(&address, &settings.template.lang, &"default".to_string(), &state).await;
let user_agent: Option<String> = match user_agent_header {
Some(TypedHeader(user_agent)) => Some(user_agent.to_string()),
@ -433,7 +438,7 @@ async fn handle_default_route(
};
state.templating_engine.render_view(
&settings,
&settings.template,
&View::Index{
result: result,
user_agent: user_agent,
@ -445,15 +450,16 @@ async fn handle_default_route(
async fn handle_search_request(
search_query: String,
this_should_have_been_an_ip: bool,
settings: TemplateSettings,
resolver_settings: ResolverQuery,
settings: QuerySettings,
arc_state: Arc<ServiceSharedState>,
) -> Response {
let search_query = search_query.trim();
let mut search_query = search_query.trim().to_string();
let mut settings = settings;
lazy_static!{
static ref ASN_REGEX: Regex = Regex::new(r"^[Aa][Ss][Nn]?\s*(\d{1,7})$").unwrap();
static ref VIA_REGEX: Regex = Regex::new(r"[Vv][Ii][Aa]\s+(\S+)").unwrap();
}
//If someone asked for an asn, give an asn answer
@ -462,22 +468,31 @@ async fn handle_search_request(
// Render a dummy template that can at least link to other pages
let state = Arc::clone(&arc_state);
return state.templating_engine.render_view(
&settings,
&settings.template,
&View::Asn{asn: asn},
).await
}
}
if let Some(via_cap) = VIA_REGEX.captures(&search_query) {
if let Some(via) = via_cap.get(1).map(|c| c.as_str().to_string()) {
let state = Arc::clone(&arc_state);
if state.dns_resolvers.contains_key(&via) {
settings.dns_resolver_id = via;
}
}
search_query = VIA_REGEX.replace(&search_query,"").trim().to_string();
}
// Try to interpret as an IP-Address
if let Ok(address) = search_query.parse() {
return handle_ip_request(address, settings, resolver_settings, arc_state).await;
return handle_ip_request(address, settings, arc_state).await;
}
// Fall back to treating it as a hostname
return handle_dig_request(
search_query.to_string(),
search_query,
settings,
resolver_settings,
arc_state,
!this_should_have_been_an_ip,
).await
@ -485,34 +500,32 @@ async fn handle_search_request(
}
async fn handle_ip_route_with_path(
Extension(settings): Extension<TemplateSettings>,
Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>,
Query(resolver_settings): Query<ResolverQuery>,
extract::Path(query): extract::Path<String>,
) -> Response {
if let Ok(address) = query.parse() {
return handle_ip_request(address, settings, resolver_settings, arc_state).await
return handle_ip_request(address, settings, arc_state).await
} else {
return handle_search_request(query, true, settings, resolver_settings, arc_state).await;
return handle_search_request(query, true, settings, arc_state).await;
}
}
async fn handle_ip_request(
address: IpAddr,
settings: TemplateSettings,
resolver_settings: ResolverQuery,
settings: QuerySettings,
arc_state: Arc<ServiceSharedState>,
) -> Response {
let state = Arc::clone(&arc_state);
let result = get_ip_result(
&address,
&settings.lang,
&resolver_settings.dns,
&settings.template.lang,
&settings.dns_resolver_id,
&state).await;
state.templating_engine.render_view(
&settings,
&settings.template,
&View::Ip{result: result}
).await
}
@ -580,18 +593,16 @@ async fn get_ip_result(
}
async fn handle_dig_route_with_path(
Query(resolver_settings): Query<ResolverQuery>,
Extension(settings): Extension<TemplateSettings>,
Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>,
extract::Path(name): extract::Path<String>,
) -> Response {
return handle_dig_request(name, settings, resolver_settings, arc_state, true).await
return handle_dig_request(name, settings, arc_state, true).await
}
async fn handle_dig_request(
dig_query: String,
settings: TemplateSettings,
resolver_settings: ResolverQuery,
settings: QuerySettings,
arc_state: Arc<ServiceSharedState>,
do_full_lookup: bool,
) -> Response {
@ -600,13 +611,13 @@ async fn handle_dig_request(
let dig_result = get_dig_result(
&dig_query,
&resolver_settings.dns,
&settings.dns_resolver_id,
&state,
do_full_lookup
).await;
state.templating_engine.render_view(
&settings,
&settings.template,
&View::Dig{ query: dig_query, result: dig_result}
).await

View File

@ -58,7 +58,7 @@ pub struct TemplateSettings {
/* The echoip view */
#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(serde::Serialize, Clone)]
#[serde(untagged)]
pub enum View {
Asn { asn: u32 },
@ -116,10 +116,10 @@ impl Engine {
_ => text.into_response(),
}
Err(e) => {
println!("There was an error while rendering template {template_name}: {e:?}");
println!("There was an error while rendering template {}: {e:?}", view.template_name());
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Template error in {template_name}, contact owner or see logs.\n")
format!("Template error in {}, contact owner or see logs.\n", view.template_name())
).into_response()
}
}