2023-02-21 00:06:49 +01:00

430 lines
10 KiB
Rust

use axum::{
extract::{
self,
Query,
State,
Extension,
},
http::Request,
middleware::{self, Next},
response::Response,
Router,
routing::get,
};
use axum_client_ip::SecureClientIp;
use clap::Parser;
use tera::Tera;
use tower::ServiceBuilder;
use trust_dns_resolver::{
TokioAsyncResolver,
// config::ResolverOpts,
// config::ResolverConfig,
};
use std::fs;
use std::net::IpAddr;
use std::sync::Arc;
use std::path::Path;
mod config;
mod geoip;
mod ipinfo;
mod simple_dns;
mod templating_engine;
use crate::geoip::QueryAsn;
use crate::geoip::QueryLocation;
use geoip::AsnResult;
use geoip::LocationResult;
use crate::templating_engine::View;
use crate::templating_engine::ResponseFormat;
use crate::ipinfo::{AddressCast,AddressInfo,AddressScope};
/// Optional query parameters that every route may receive.
///
/// `format_and_language_middleware` turns this into a [`QuerySettings`],
/// filling in defaults for the values that are absent.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct BaseQuery {
    /// Requested response format, if any.
    format: Option<ResponseFormat>,
    /// Requested language, if any.
    lang: Option<String>,
}
/// Response format and language resolved for the current request.
///
/// Inserted as a request extension by `format_and_language_middleware` and
/// extracted by the route handlers via `Extension<QuerySettings>`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct QuerySettings {
    /// Format the response is rendered in.
    format: ResponseFormat,
    /// Language code used for localized lookups.
    lang: String,
}
/// The address an IP lookup is performed for.
///
/// Comes from the `ip` query parameter on `/ip`, the path segment on
/// `/ip/:address`, or the client's own address on `/`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct IpQuery {
    ip: IpAddr,
}
/// The domain name a DNS lookup is performed for.
///
/// Comes from the `name` query parameter on `/dig` or the path segment on
/// `/dig/:name`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct DigQuery {
    name: String,
}
/// Everything reported about a single IP address.
///
/// serde serializes fields in declaration order, so the field order below is
/// also the order of the serialized output.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct IpResult {
    /// Reverse DNS name; `None` if not looked up, not found, or hidden.
    hostname: Option<String>,
    /// Result of the ASN database lookup.
    asn: Option<AsnResult>,
    /// Result of the location database lookup.
    location: Option<LocationResult>,
    /// Classification of the address as computed by `AddressInfo::new`.
    ip_info: AddressInfo,
}
/// Long-lived resources shared by all request handlers.
///
/// Built once in `main`, wrapped in an `Arc` and handed to the router as
/// axum `State`.
struct ServiceSharedState {
    /// Renders views in the requested response format.
    templating_engine: templating_engine::Engine,
    /// Resolver used for reverse lookups and `/dig` queries.
    dns_resolver: TokioAsyncResolver,
    /// ASN database carrier; only loaded when a path is configured.
    asn_db: geoip::MMDBCarrier,
    /// Location database carrier; only loaded when a path is configured.
    location_db: geoip::MMDBCarrier,
    /// Service configuration.
    config: config::EchoIpServiceConfig,
}
// Command line interface of the service. All options are optional; see
// `main` for how each value is combined with the configuration file.
// (Plain `//` comments on purpose: `///` doc comments would become clap
// help text and change `--help` output.)
#[derive(Parser)]
#[command(author, version, long_about="A web service that tells you your ip-address and more …")]
struct CliArgs {
    // `-c`, `--config`: path to the TOML service configuration file.
    #[arg(short, long)]
    config: Option<String>,
    // `-l`, `--listen-on`: address to listen on, as given on the command line.
    #[arg(short, long)]
    listen_on: Option<String>,
    // `-t`, `--templates`: template directory, as given on the command line.
    #[arg(short, long)]
    templates: Option<String>,
    // `-e`, `--extra-config`: path to a TOML file with extra template
    // configuration.
    #[arg(short, long)]
    extra_config: Option<String>,
}
/// Returns `true` if `domain` ends with any suffix in `hidden_list`.
///
/// Trailing dots (fully-qualified form) are ignored on both the domain and
/// each suffix. The comparison is ASCII case-insensitive because DNS names
/// are case-insensitive: without this, `SECRET.example.com` would slip past
/// a hidden suffix of `example.com`.
///
/// Matching is a plain string suffix test. A suffix of `example.com` also
/// matches `myexample.com`, and a suffix that is empty (or only dots)
/// matches every name.
///
/// On a match, prints `Blocked <domain>` (trailing dots removed).
fn match_domain_hidden_list(domain: &str, hidden_list: &[String]) -> bool {
    let name = domain.trim_end_matches(".");
    let normalized = name.to_ascii_lowercase();
    let is_hidden = hidden_list.iter().any(|suffix| {
        let suffix = suffix.trim_end_matches(".").to_ascii_lowercase();
        normalized.ends_with(suffix.as_str())
    });
    if is_hidden {
        println!("Blocked {name}");
    }
    is_hidden
}
/// Reads the file at `path` and deserializes its contents from TOML into `T`.
///
/// If the file cannot be read or cannot be parsed as `T`, a message is
/// printed to stdout and `None` is returned.
fn read_toml_from_file<T: for<'de> serde::Deserialize<'de>>(path: &String) -> Option<T> {
    let contents = fs::read_to_string(path)
        .map_err(|e| println!("Error while reading file '{path}': {e}"))
        .ok()?;
    toml::from_str(&contents)
        .map_err(|e| println!("Unable to parse file '{path}':\n{e}"))
        .ok()
}
/// Service entry point.
///
/// Startup steps:
/// - Load the configuration (from `--config`, otherwise built-in defaults).
/// - Apply command line overrides.
/// - Parse the Tera templates and load the optional GeoIP databases.
/// - Set up a DNS resolver from the system configuration.
/// - Serve the HTTP routes.
///
/// Setup and server errors are printed and end the process with exit
/// code 1.
///
/// Command line options take precedence over the configuration file:
/// - `--listen-on` replaces `server.listen_on`.
/// - `--templates` replaces `template.template_location`.
/// - `--extra-config` replaces `template.extra_config`.
#[tokio::main]
async fn main() {
    // Parse command line arguments
    let cli_args = CliArgs::parse();

    // Read configuration file
    let mut config: config::EchoIpServiceConfig = match &cli_args.config {
        Some(config_path) => {
            match read_toml_from_file::<config::EchoIpServiceConfig>(config_path) {
                Some(c) => c,
                None => {
                    println!("Could not read configuration file, exiting.");
                    ::std::process::exit(1);
                }
            }
        }
        None => Default::default(),
    };

    // Command line options take precedence over the configuration file.
    if let Some(listen_on) = &cli_args.listen_on {
        config.server.listen_on = match listen_on.parse() {
            Ok(address) => address,
            Err(e) => {
                println!("Invalid listen address '{listen_on}': {e}");
                ::std::process::exit(1);
            }
        };
    }
    if let Some(templates) = &cli_args.templates {
        config.template.template_location = templates.clone();
    }

    // Initialize Tera templates
    let mut template_base_dir = config.template.template_location.clone();
    if !template_base_dir.ends_with('/') {
        template_base_dir.push('/');
    }
    let extra_config_path = cli_args
        .extra_config
        .as_ref()
        .or(config.template.extra_config.as_ref());
    let template_extra_config = match extra_config_path {
        Some(path) => read_toml_from_file(path),
        None => {
            println!("Trying to read default template configuration ...");
            println!("(If this fails that may be ok, depending on your template)");
            read_toml_from_file(&(template_base_dir.clone() + "extra.toml"))
        }
    };
    let template_glob = template_base_dir + "*.html";
    println!("Parsing Templates from '{}' ...", &template_glob);
    let tera = match Tera::new(template_glob.as_str()) {
        Ok(t) => t,
        Err(e) => {
            println!("Template parsing error(s): {}", e);
            ::std::process::exit(1);
        }
    };
    let templating_engine = templating_engine::Engine {
        tera,
        template_config: template_extra_config,
    };

    // Initialize GeoIP databases. The load result is discarded, so startup
    // continues whether or not loading succeeds.
    // NOTE(review): presumably load_database_from_path reports its own
    // errors — confirm, otherwise a failed load is silent.
    let mut asn_db = geoip::MMDBCarrier {
        mmdb: None,
        name: "GeoIP ASN Database".to_string(),
    };
    if let Some(path) = &config.geoip.asn_database {
        let _ = asn_db.load_database_from_path(Path::new(&path));
    }
    let mut location_db = geoip::MMDBCarrier {
        mmdb: None,
        name: "GeoIP Location Database".to_string(),
    };
    if let Some(path) = &config.geoip.location_database {
        let _ = location_db.load_database_from_path(Path::new(&path));
    }

    // Initialize DNS resolver with OS defaults
    println!("Initializing dns resolver ...");
    println!("Using System configuration ...");
    let dns_resolver = match TokioAsyncResolver::tokio_from_system_conf() {
        Ok(resolver) => resolver,
        Err(e) => {
            println!("Error while setting up dns resolver: {e}");
            ::std::process::exit(1);
        }
    };

    // Read after the command line overrides have been applied.
    let listen_on = config.server.listen_on;
    let ip_header = config.server.ip_header.clone();

    // Initialize shared state
    let shared_state = Arc::new(ServiceSharedState {
        templating_engine,
        dns_resolver,
        asn_db,
        location_db,
        config: config.clone(),
    });

    // Initialize axum server
    let app = Router::new()
        .route("/", get(handle_default_route))
        .route("/dig", get(handle_dig_route))
        .route("/dig/:name", get(handle_dig_route_with_path))
        .route("/ip", get(handle_ip_route))
        .route("/ip/:address", get(handle_ip_route_with_path))
        .route("/hi", get(hello_world_handler))
        .with_state(shared_state)
        .layer(
            ServiceBuilder::new()
                .layer(ip_header.into_extension())
                .layer(Extension(config))
                .layer(middleware::from_fn(format_and_language_middleware)),
        );

    println!("Starting Server ...");
    if let Err(e) = axum::Server::bind(&listen_on)
        .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
        .await
    {
        println!("Server error: {e}");
        ::std::process::exit(1);
    }
}
/// Middleware that resolves the response format and language for a request.
///
/// Reads the optional `format` and `lang` query parameters (see
/// [`BaseQuery`]). A missing format defaults to `ResponseFormat::TextHtml`
/// and a missing language to `"en"`. The result is stored as a
/// [`QuerySettings`] request extension, which the handlers extract.
///
/// If the query string cannot be deserialized into `BaseQuery`, the `Query`
/// extractor rejects the request before this body runs.
async fn format_and_language_middleware<B>(
    Query(query): Query<BaseQuery>,
    mut req: Request<B>,
    next: Next<B>,
) -> Response {
    let settings = QuerySettings {
        format: query.format.unwrap_or(ResponseFormat::TextHtml),
        lang: query.lang.unwrap_or_else(|| "en".to_string()),
    };
    req.extensions_mut().insert(settings);
    next.run(req).await
}
/// Renders a fixed greeting message in the requested response format.
#[axum::debug_handler]
async fn hello_world_handler(
    State(arc_state): State<Arc<ServiceSharedState>>,
    Extension(settings): Extension<QuerySettings>,
) -> Response {
    let greeting = View::Message("Hello! There, You, Awesome Creature!".to_string());
    arc_state
        .templating_engine
        .render_view(settings.format, &greeting)
        .await
}
/// Index route: reports on the requesting client's own address.
///
/// The client address comes from the `SecureClientIp` extractor.
async fn handle_default_route(
    State(arc_state): State<Arc<ServiceSharedState>>,
    Extension(settings): Extension<QuerySettings>,
    SecureClientIp(address): SecureClientIp
) -> Response {
    let query = IpQuery { ip: address };
    let result = get_ip_result(&query, &settings.lang, &arc_state).await;
    let view = View::Index { query, result };
    arc_state
        .templating_engine
        .render_view(settings.format, &view)
        .await
}
/// IP lookup with the address supplied as the `ip` query parameter.
async fn handle_ip_route(
    Query(ip_query): Query<IpQuery>,
    Extension(settings): Extension<QuerySettings>,
    State(arc_state): State<Arc<ServiceSharedState>>,
) -> Response {
    handle_ip_request(ip_query, settings, arc_state).await
}
/// IP lookup with the address supplied as a path segment.
async fn handle_ip_route_with_path(
    Extension(settings): Extension<QuerySettings>,
    State(arc_state): State<Arc<ServiceSharedState>>,
    extract::Path(address): extract::Path<IpAddr>,
) -> Response {
    let ip_query = IpQuery { ip: address };
    handle_ip_request(ip_query, settings, arc_state).await
}
/// Shared implementation of the `/ip` routes.
///
/// Gathers the lookup result for `ip_query` and renders it as an `Ip` view.
async fn handle_ip_request(
    ip_query: IpQuery,
    settings: QuerySettings,
    arc_state: Arc<ServiceSharedState>,
) -> Response {
    let result = get_ip_result(&ip_query, &settings.lang, &arc_state).await;
    let view = View::Ip { query: ip_query, result };
    arc_state
        .templating_engine
        .render_view(settings.format, &view)
        .await
}
/// Collects everything the service reports about `ip_query.ip`.
///
/// Lookups are performed only when either:
/// - the address is unicast with global or shared scope, or
/// - it has private or link-local scope and
///   `server.allow_private_ip_lookup` is enabled.
///
/// Any other address gets only its `AddressInfo` classification.
///
/// The lookups are:
/// - an optional reverse DNS lookup (`dns.allow_reverse_lookup`), whose
///   hostname is dropped if it falls under a hidden suffix;
/// - an ASN lookup;
/// - a location lookup with language preference `[lang, "en"]`.
async fn get_ip_result(
    ip_query: &IpQuery,
    lang: &String,
    state: &ServiceSharedState,
) -> IpResult {
    let address = ip_query.ip;
    let ip_info = AddressInfo::new(&address);

    let is_public_unicast = ip_info.cast == AddressCast::Unicast
        && (ip_info.scope == AddressScope::Global || ip_info.scope == AddressScope::Shared);
    // NOTE(review): this branch does not check the cast, so non-unicast
    // addresses with private/link-local scope are looked up when private
    // lookups are allowed — confirm that is intended.
    let is_permitted_private = state.config.server.allow_private_ip_lookup
        && (ip_info.scope == AddressScope::Private || ip_info.scope == AddressScope::LinkLocal);

    if !(is_public_unicast || is_permitted_private) {
        return IpResult {
            hostname: None,
            asn: None,
            location: None,
            ip_info,
        };
    }

    let reverse_name = if state.config.dns.allow_reverse_lookup {
        simple_dns::reverse_lookup(&state.dns_resolver, &address).await
    } else {
        None
    };
    let asn = state.asn_db.query_asn_for_ip(address);
    let location = state.location_db.query_location_for_ip(
        address,
        &vec![lang, &"en".to_string()]
    );
    let hostname = reverse_name
        .filter(|name| !match_domain_hidden_list(name, &state.config.dns.hidden_suffixes));

    IpResult {
        hostname,
        asn,
        location,
        ip_info,
    }
}
/// DNS lookup with the name supplied as the `name` query parameter.
async fn handle_dig_route(
    Query(dig_query): Query<DigQuery>,
    Extension(settings): Extension<QuerySettings>,
    State(arc_state): State<Arc<ServiceSharedState>>,
) -> Response {
    handle_dig_request(dig_query, settings, arc_state).await
}
/// DNS lookup with the name supplied as a path segment.
async fn handle_dig_route_with_path(
    Extension(settings): Extension<QuerySettings>,
    State(arc_state): State<Arc<ServiceSharedState>>,
    extract::Path(name): extract::Path<String>,
) -> Response {
    let dig_query = DigQuery { name };
    handle_dig_request(dig_query, settings, arc_state).await
}
/// Shared implementation of the `/dig` routes.
///
/// Runs the lookup for `dig_query` and renders it as a `Dig` view.
async fn handle_dig_request(
    dig_query: DigQuery,
    settings: QuerySettings,
    arc_state: Arc<ServiceSharedState>,
) -> Response {
    let result = get_dig_result(&dig_query, &arc_state).await;
    let view = View::Dig { query: dig_query, result };
    arc_state
        .templating_engine
        .render_view(settings.format, &view)
        .await
}
/// Performs the DNS lookup for `dig_query.name`.
///
/// The name is trimmed of surrounding whitespace and trailing dots first.
/// Names under a hidden suffix yield the default (empty) result without
/// querying DNS.
async fn get_dig_result(
    dig_query: &DigQuery,
    state: &ServiceSharedState,
) -> simple_dns::DnsLookupResult {
    let normalized: String = dig_query.name.trim().trim_end_matches(".").to_string();
    if match_domain_hidden_list(&normalized, &state.config.dns.hidden_suffixes) {
        return Default::default();
    }
    simple_dns::lookup(&state.dns_resolver, &normalized, true).await
}