Middlewarre!

This commit is contained in:
Slatian 2023-02-21 00:06:49 +01:00
parent 2abc5844ad
commit 52ace5f61f
8 changed files with 106 additions and 62 deletions

14
Cargo.lock generated
View File

@ -45,6 +45,7 @@ checksum = "678c5130a507ae3a7c797f9a17393c14849300b8440eac47cdb90a5bdcb3a543"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"axum-macros",
"bitflags", "bitflags",
"bytes", "bytes",
"futures-util", "futures-util",
@ -98,6 +99,18 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-macros"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fbf955307ff8addb48d2399393c9e2740dd491537ec562b66ab364fc4a38841"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -1150,6 +1163,7 @@ dependencies = [
"tera", "tera",
"tokio", "tokio",
"toml", "toml",
"tower",
"trust-dns-resolver", "trust-dns-resolver",
] ]

View File

@ -7,12 +7,13 @@ authors = ["Slatian <baschdel@disroot.org>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
axum = "0.6" axum = { version = "0.6", features = ["macros"] }
axum-client-ip = "0.4" axum-client-ip = "0.4"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tera = "1" tera = "1"
toml = "0.7" toml = "0.7"
tower = "*"
trust-dns-resolver = "0.22" trust-dns-resolver = "0.22"
maxminddb = "0.17" maxminddb = "0.17"

View File

@ -1,7 +1,7 @@
use axum_client_ip::SecureClientIpSource; use axum_client_ip::SecureClientIpSource;
use std::net::SocketAddr; use std::net::SocketAddr;
#[derive(serde::Deserialize, Default)] #[derive(serde::Deserialize, Default, Clone)]
pub struct EchoIpServiceConfig { pub struct EchoIpServiceConfig {
pub server: ServerConfig, pub server: ServerConfig,
pub dns: DnsConfig, pub dns: DnsConfig,
@ -9,7 +9,7 @@ pub struct EchoIpServiceConfig {
pub template: TemplateConfig, pub template: TemplateConfig,
} }
#[derive(serde::Deserialize)] #[derive(serde::Deserialize, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
pub listen_on: SocketAddr, pub listen_on: SocketAddr,
pub ip_header: SecureClientIpSource, pub ip_header: SecureClientIpSource,
@ -17,7 +17,7 @@ pub struct ServerConfig {
pub allow_private_ip_lookup: bool, pub allow_private_ip_lookup: bool,
} }
#[derive(serde::Deserialize)] #[derive(serde::Deserialize, Clone)]
pub struct DnsConfig { pub struct DnsConfig {
pub allow_forward_lookup: bool, pub allow_forward_lookup: bool,
pub allow_reverse_lookup: bool, pub allow_reverse_lookup: bool,
@ -25,13 +25,13 @@ pub struct DnsConfig {
//Future Idea: allow custom resolver //Future Idea: allow custom resolver
} }
#[derive(serde::Deserialize)] #[derive(serde::Deserialize, Clone)]
pub struct GeoIpConfig { pub struct GeoIpConfig {
pub asn_database: Option<String>, pub asn_database: Option<String>,
pub location_database: Option<String>, pub location_database: Option<String>,
} }
#[derive(serde::Deserialize)] #[derive(serde::Deserialize, Clone)]
pub struct TemplateConfig { pub struct TemplateConfig {
pub template_location: String, pub template_location: String,
pub extra_config: Option<String>, pub extra_config: Option<String>,

View File

@ -12,20 +12,20 @@ use std::path::Path;
/* Datatypes */ /* Datatypes */
#[derive(serde::Deserialize, serde::Serialize, Default)] #[derive(serde::Deserialize, serde::Serialize, Default, Clone)]
pub struct NamedLocation { pub struct NamedLocation {
iso_code: Option<String>, iso_code: Option<String>,
name: Option<String>, name: Option<String>,
geoname_id: Option<u32>, geoname_id: Option<u32>,
} }
#[derive(serde::Deserialize, serde::Serialize, Default)] #[derive(serde::Deserialize, serde::Serialize, Default, Copy, Clone)]
pub struct LocationCoordinates { pub struct LocationCoordinates {
latitude: f64, latitude: f64,
logtitude: f64, logtitude: f64,
} }
#[derive(serde::Deserialize, serde::Serialize, Default)] #[derive(serde::Deserialize, serde::Serialize, Default, Clone)]
pub struct LocationResult { pub struct LocationResult {
continent: Option<NamedLocation>, continent: Option<NamedLocation>,
country: Option<NamedLocation>, country: Option<NamedLocation>,
@ -38,7 +38,7 @@ pub struct LocationResult {
time_zone: Option<String>, time_zone: Option<String>,
} }
#[derive(serde::Deserialize, serde::Serialize, Default)] #[derive(serde::Deserialize, serde::Serialize, Default, Clone)]
pub struct AsnResult { pub struct AsnResult {
asn: Option<u32>, asn: Option<u32>,
name: Option<String>, name: Option<String>,

View File

@ -8,7 +8,7 @@
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Default, PartialEq)] #[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
#[serde(rename_all="lowercase")] #[serde(rename_all="lowercase")]
pub enum AddressCast { pub enum AddressCast {
Unspecified, Unspecified,
@ -18,7 +18,7 @@ pub enum AddressCast {
Broadcast, Broadcast,
} }
#[derive(Serialize, Deserialize, Default, PartialEq)] #[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
#[serde(rename_all="lowercase")] #[serde(rename_all="lowercase")]
pub enum AddressScope { pub enum AddressScope {
Global, Global,
@ -32,7 +32,7 @@ pub enum AddressScope {
Unknown, Unknown,
} }
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default, Clone)]
pub struct AddressInfo { pub struct AddressInfo {
pub is_v6_address: bool, pub is_v6_address: bool,
pub cast: AddressCast, pub cast: AddressCast,

View File

@ -1,7 +1,12 @@
use axum::{ use axum::{
extract::Query, extract::{
extract::State, self,
extract, Query,
State,
Extension,
},
http::Request,
middleware::{self, Next},
response::Response, response::Response,
Router, Router,
routing::get, routing::get,
@ -9,6 +14,7 @@ use axum::{
use axum_client_ip::SecureClientIp; use axum_client_ip::SecureClientIp;
use clap::Parser; use clap::Parser;
use tera::Tera; use tera::Tera;
use tower::ServiceBuilder;
use trust_dns_resolver::{ use trust_dns_resolver::{
TokioAsyncResolver, TokioAsyncResolver,
// config::ResolverOpts, // config::ResolverOpts,
@ -36,26 +42,29 @@ use crate::templating_engine::ResponseFormat;
use crate::ipinfo::{AddressCast,AddressInfo,AddressScope}; use crate::ipinfo::{AddressCast,AddressInfo,AddressScope};
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct BaseQuery { pub struct BaseQuery {
format: Option<ResponseFormat>, format: Option<ResponseFormat>,
lang: Option<String>, lang: Option<String>,
} }
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct QuerySettings {
format: ResponseFormat,
lang: String,
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct IpQuery { pub struct IpQuery {
format: Option<ResponseFormat>,
lang: Option<String>,
ip: IpAddr, ip: IpAddr,
} }
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct DigQuery { pub struct DigQuery {
format: Option<ResponseFormat>,
name: String, name: String,
} }
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct IpResult { pub struct IpResult {
hostname: Option<String>, hostname: Option<String>,
asn: Option<AsnResult>, asn: Option<AsnResult>,
@ -158,6 +167,11 @@ async fn main() {
} }
}; };
let templating_engine = templating_engine::Engine{
tera: tera,
template_config: template_extra_config,
};
// Initalize GeoIP Database // Initalize GeoIP Database
let mut asn_db = geoip::MMDBCarrier { let mut asn_db = geoip::MMDBCarrier {
@ -198,14 +212,11 @@ async fn main() {
// Initialize shared state // Initialize shared state
let shared_state = Arc::new( let shared_state = Arc::new(
ServiceSharedState { ServiceSharedState {
templating_engine: templating_engine::Engine{ templating_engine: templating_engine,
tera: tera,
template_config: template_extra_config,
},
dns_resolver: dns_resolver, dns_resolver: dns_resolver,
asn_db: asn_db, asn_db: asn_db,
location_db: location_db, location_db: location_db,
config: config, config: config.clone(),
}); });
// Initalize axum server // Initalize axum server
@ -217,7 +228,12 @@ async fn main() {
.route("/ip/:address", get(handle_ip_route_with_path)) .route("/ip/:address", get(handle_ip_route_with_path))
.route("/hi", get(hello_world_handler)) .route("/hi", get(hello_world_handler))
.with_state(shared_state) .with_state(shared_state)
.layer(
ServiceBuilder::new()
.layer(ip_header.into_extension()) .layer(ip_header.into_extension())
.layer(Extension(config))
.layer(middleware::from_fn(format_and_language_middleware))
)
; ;
println!("Starting Server ..."); println!("Starting Server ...");
@ -228,77 +244,90 @@ async fn main() {
.unwrap(); .unwrap();
} }
async fn format_and_language_middleware<B>(
Query(query): Query<BaseQuery>,
Extension(config): Extension<config::EchoIpServiceConfig>,
mut req: Request<B>,
next: Next<B>
) -> Response {
let format = query.format.unwrap_or(ResponseFormat::TextHtml);
req.extensions_mut().insert(QuerySettings{
format: format,
lang: query.lang.unwrap_or("en".to_string()),
});
next.run(req).await
}
#[axum::debug_handler]
async fn hello_world_handler( async fn hello_world_handler(
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
Extension(settings): Extension<QuerySettings>,
) -> Response { ) -> Response {
let state = Arc::clone(&arc_state); let state = Arc::clone(&arc_state);
state.templating_engine.render_view( state.templating_engine.render_view(
ResponseFormat::TextPlain, settings.format,
View::Message("Hello! There, You, Awesome Creature!".to_string()) &View::Message("Hello! There, You, Awesome Creature!".to_string())
).await ).await
} }
async fn handle_default_route( async fn handle_default_route(
Query(query): Query<BaseQuery>,
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
Extension(settings): Extension<QuerySettings>,
SecureClientIp(address): SecureClientIp SecureClientIp(address): SecureClientIp
) -> Response { ) -> Response {
let format = query.format.unwrap_or(ResponseFormat::TextHtml);
let ip_query = IpQuery { let ip_query = IpQuery {
format: query.format,
lang: query.lang,
ip: address, ip: address,
}; };
let state = Arc::clone(&arc_state); let state = Arc::clone(&arc_state);
let result = get_ip_result(&ip_query, &state).await; let result = get_ip_result(&ip_query, &settings.lang, &state).await;
state.templating_engine.render_view( state.templating_engine.render_view(
format, settings.format,
View::Index{query: ip_query, result: result} &View::Index{query: ip_query, result: result}
).await ).await
} }
async fn handle_ip_route( async fn handle_ip_route(
Query(ip_query): Query<IpQuery>, Query(ip_query): Query<IpQuery>,
Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
) -> Response { ) -> Response {
return handle_ip_request(ip_query, arc_state).await return handle_ip_request(ip_query, settings, arc_state).await
} }
async fn handle_ip_route_with_path( async fn handle_ip_route_with_path(
Query(query): Query<BaseQuery>, Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
extract::Path(address): extract::Path<IpAddr>, extract::Path(address): extract::Path<IpAddr>,
) -> Response { ) -> Response {
return handle_ip_request(IpQuery { return handle_ip_request(IpQuery {
format: query.format,
lang: query.lang,
ip: address, ip: address,
}, arc_state).await }, settings, arc_state).await
} }
async fn handle_ip_request( async fn handle_ip_request(
ip_query: IpQuery, ip_query: IpQuery,
settings: QuerySettings,
arc_state: Arc<ServiceSharedState>, arc_state: Arc<ServiceSharedState>,
) -> Response { ) -> Response {
let state = Arc::clone(&arc_state); let state = Arc::clone(&arc_state);
let result = get_ip_result(&ip_query, &state).await; let result = get_ip_result(&ip_query, &settings.lang, &state).await;
let format = ip_query.format.unwrap_or(ResponseFormat::TextHtml);
state.templating_engine.render_view( state.templating_engine.render_view(
format, settings.format,
View::Ip{query: ip_query, result: result} &View::Ip{query: ip_query, result: result}
).await ).await
} }
async fn get_ip_result( async fn get_ip_result(
ip_query: &IpQuery, ip_query: &IpQuery,
lang: &String,
state: &ServiceSharedState, state: &ServiceSharedState,
) -> IpResult { ) -> IpResult {
let address = ip_query.ip; let address = ip_query.ip;
@ -329,7 +358,7 @@ async fn get_ip_result(
// location lookup // location lookup
let location_result = state.location_db.query_location_for_ip( let location_result = state.location_db.query_location_for_ip(
address, address,
&vec![&ip_query.lang.as_ref().unwrap_or(&"en".to_string()), &"en".to_string()] &vec![lang, &"en".to_string()]
); );
// filter reverse lookup // filter reverse lookup
@ -354,35 +383,35 @@ async fn get_ip_result(
async fn handle_dig_route( async fn handle_dig_route(
Query(dig_query): Query<DigQuery>, Query(dig_query): Query<DigQuery>,
Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
) -> Response { ) -> Response {
return handle_dig_request(dig_query, arc_state).await return handle_dig_request(dig_query, settings, arc_state).await
} }
async fn handle_dig_route_with_path( async fn handle_dig_route_with_path(
Query(query): Query<BaseQuery>, Extension(settings): Extension<QuerySettings>,
State(arc_state): State<Arc<ServiceSharedState>>, State(arc_state): State<Arc<ServiceSharedState>>,
extract::Path(name): extract::Path<String>, extract::Path(name): extract::Path<String>,
) -> Response { ) -> Response {
return handle_dig_request(DigQuery { return handle_dig_request(DigQuery {
format: query.format,
name: name, name: name,
}, arc_state).await }, settings, arc_state).await
} }
async fn handle_dig_request( async fn handle_dig_request(
dig_query: DigQuery, dig_query: DigQuery,
settings: QuerySettings,
arc_state: Arc<ServiceSharedState>, arc_state: Arc<ServiceSharedState>,
) -> Response { ) -> Response {
let state = Arc::clone(&arc_state); let state = Arc::clone(&arc_state);
let format = dig_query.format.unwrap_or(ResponseFormat::TextHtml);
let dig_result = get_dig_result(&dig_query, &state).await; let dig_result = get_dig_result(&dig_query, &state).await;
state.templating_engine.render_view( state.templating_engine.render_view(
format, settings.format,
View::Dig{ query: dig_query, result: dig_result} &View::Dig{ query: dig_query, result: dig_result}
).await ).await
} }

View File

@ -16,15 +16,14 @@ use std::net::IpAddr;
/* Data Structures */ /* Data Structures */
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Default, Clone)]
#[derive(Default)]
pub struct DnsLookupResult { pub struct DnsLookupResult {
a: Vec<IpAddr>, a: Vec<IpAddr>,
aaaa: Vec<IpAddr>, aaaa: Vec<IpAddr>,
mx: Vec<MxRecord>, mx: Vec<MxRecord>,
} }
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct MxRecord { pub struct MxRecord {
preference: u16, preference: u16,
exchange: String, exchange: String,

View File

@ -42,7 +42,7 @@ impl ToString for ResponseFormat {
/* The echoip view */ /* The echoip view */
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize, Clone)]
#[serde(untagged)] #[serde(untagged)]
pub enum View { pub enum View {
Dig { query: DigQuery, result: simple_dns::DnsLookupResult }, Dig { query: DigQuery, result: simple_dns::DnsLookupResult },
@ -67,6 +67,7 @@ impl View {
/* The engine itself */ /* The engine itself */
#[derive(Clone)]
pub struct Engine { pub struct Engine {
pub tera: Tera, pub tera: Tera,
pub template_config: Option<Table>, pub template_config: Option<Table>,
@ -76,7 +77,7 @@ impl Engine {
pub async fn render_view( pub async fn render_view(
&self, &self,
format: ResponseFormat, format: ResponseFormat,
view: View, view: &View,
) -> Response { ) -> Response {
match format { match format {
ResponseFormat::TextHtml => { ResponseFormat::TextHtml => {