Added some ratelimiting middleware

This commit is contained in:
Slatian
2023-02-25 12:14:50 +01:00
parent 9f3b6d0c17
commit a48050b234
8 changed files with 307 additions and 18 deletions

View File

@ -1,5 +1,6 @@
use axum_client_ip::SecureClientIpSource;
use std::net::SocketAddr;
use std::num::NonZeroU32;
#[derive(serde::Deserialize, Default, Clone)]
pub struct EchoIpServiceConfig {
@ -7,6 +8,7 @@ pub struct EchoIpServiceConfig {
pub dns: DnsConfig,
pub geoip: GeoIpConfig,
pub template: TemplateConfig,
pub ratelimit: RatelimitConfig,
}
#[derive(serde::Deserialize, Clone)]
@ -38,6 +40,12 @@ pub struct TemplateConfig {
pub text_user_agents: Vec<String>,
}
#[derive(serde::Deserialize, Clone)]
pub struct RatelimitConfig {
pub per_minute: NonZeroU32,
pub burst: NonZeroU32,
}
impl Default for ServerConfig {
fn default() -> Self {
ServerConfig {
@ -76,3 +84,12 @@ impl Default for TemplateConfig {
}
}
}
impl Default for RatelimitConfig {
fn default() -> Self {
RatelimitConfig {
per_minute: NonZeroU32::new(20).unwrap(),
burst: NonZeroU32::new(15).unwrap(),
}
}
}

View File

@ -33,6 +33,7 @@ use std::path::Path;
mod config;
mod geoip;
mod ipinfo;
mod ratelimit;
mod simple_dns;
mod templating_engine;
mod idna;
@ -85,6 +86,7 @@ pub struct DigResult {
partial_lookup: bool,
}
struct ServiceSharedState {
templating_engine: templating_engine::Engine,
dns_resolver: TokioAsyncResolver,
@ -110,7 +112,6 @@ fn match_domain_hidden_list(domain: &String, hidden_list: &Vec<String>) -> bool
let name = domain.trim_end_matches(".");
for suffix in hidden_list {
if name.ends_with(suffix) {
println!("Blocked {name}");
return true;
}
}
@ -185,6 +186,9 @@ async fn main() {
template_config: template_extra_config,
};
// Initalize Rate Limiter
// Initalize GeoIP Database
let mut asn_db = geoip::MMDBCarrier {
@ -243,9 +247,12 @@ async fn main() {
.with_state(shared_state)
.layer(
ServiceBuilder::new()
.layer(ip_header.into_extension())
.layer(Extension(config))
.layer(middleware::from_fn(format_and_language_middleware))
.layer(ip_header.into_extension())
.layer(ratelimit::build_rate_limiting_state(
config.ratelimit.per_minute, config.ratelimit.burst))
.layer(middleware::from_fn(ratelimit::rate_limit_middleware))
.layer(Extension(config))
.layer(middleware::from_fn(format_and_language_middleware))
)
;
@ -257,6 +264,7 @@ async fn main() {
.unwrap();
}
async fn format_and_language_middleware<B>(
Query(query): Query<SettingsQuery>,
Extension(config): Extension<config::EchoIpServiceConfig>,

71
src/ratelimit.rs Normal file
View File

@ -0,0 +1,71 @@
use axum_client_ip::SecureClientIp;
use axum::{
extract::Extension,
http::{
Request,
StatusCode,
},
middleware::Next,
response::{
IntoResponse,
Response,
},
};
use governor::{
clock::DefaultClock,
Quota,
RateLimiter,
state::keyed::DefaultKeyedStateStore,
};
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::Arc;
pub type SimpleRateLimiter<Key> =
RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock>;
pub fn build_rate_limiting_state(
requests_per_minute: NonZeroU32,
request_burst_capacity: NonZeroU32,
) -> Extension<Arc<SimpleRateLimiter<IpAddr>>> {
let quota = Quota::per_minute(requests_per_minute)
.allow_burst(request_burst_capacity);
let arc_limiter : Arc<SimpleRateLimiter<IpAddr>> = Arc::new(
RateLimiter::keyed(quota)
);
Extension(arc_limiter)
}
pub async fn rate_limit_middleware<B>(
SecureClientIp(address): SecureClientIp,
Extension(arc_limiter): Extension<Arc<SimpleRateLimiter<IpAddr>>>,
req: Request<B>,
next: Next<B>
) -> Response {
let limiter = Arc::clone(&arc_limiter);
match limiter.check_key(&address) {
Ok(_) => {
//Little hack to prevent too many cleanups in cases of very high load
if limiter.check_key(&IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)).is_ok() {
let oldlen = limiter.len();
if oldlen > 100 {
println!("Doing limiter cleanup ...");
limiter.retain_recent();
limiter.shrink_to_fit();
println!("Old limiter store size: {oldlen} New limiter store size: {}", limiter.len());
}
}
next.run(req).await
},
Err(_) => (
StatusCode::TOO_MANY_REQUESTS,
"You make too many requests! Please slow down a bit."
).into_response(),
}
}