From caf47522e4cc56668a83e192779ac65689eba13b Mon Sep 17 00:00:00 2001 From: Slatian Date: Sun, 9 Feb 2025 16:10:35 +0100 Subject: [PATCH] Use a fallback for when the requested dns resolver isn't available --- src/main.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1aa0fc8..3fcf0bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -310,7 +310,7 @@ async fn main() { ServeDir::new(static_file_directory) .fallback(not_found_handler.with_state(shared_state.clone())) ) - .with_state(shared_state) + .with_state(shared_state.clone()) .layer( ServiceBuilder::new() .layer(ip_header.into_extension()) @@ -319,7 +319,7 @@ async fn main() { .layer(middleware::from_fn(ratelimit::rate_limit_middleware)) .layer(Extension(config)) .layer(Extension(derived_config)) - .layer(middleware::from_fn(settings_query_middleware)) + .layer(middleware::from_fn_with_state(shared_state, settings_query_middleware)) ) ; @@ -331,26 +331,38 @@ async fn main() { .unwrap(); } - +#[allow(clippy::too_many_arguments)] async fn settings_query_middleware( Query(query): Query, + State(arc_state): State>, Extension(config): Extension, Extension(derived_config): Extension, cookie_header: Option>, user_agent_header: Option>, mut req: Request, - next: Next + next: Next, ) -> Response { + let state = Arc::clone(&arc_state); let mut format = query.format; - let mut dns_resolver_id = derived_config.default_resolver; + + let mut dns_resolver_id = derived_config.default_resolver.clone(); + let mut test_for_resolver = false; if let Some(resolver_id) = query.dns { dns_resolver_id = resolver_id.into(); + test_for_resolver = true; } else if let Some(cookie_header) = cookie_header { if let Some(resolver_id) = cookie_header.0.get("dns_resolver") { dns_resolver_id = resolver_id.into(); + test_for_resolver = true; } } + + // Falls back to the default resolver if an invalid resolver id ws requested. + // This may be the case for bookmarked links or old cookies of a resolver was removed. + if test_for_resolver && !state.dns_resolvers.contains_key(&dns_resolver_id) { + dns_resolver_id = derived_config.default_resolver; + } // Try to guess type from user agent if format.is_none() {