use axum_client_ip::SecureClientIp; use axum::{ body::Body, extract::Extension, http::{ Request, StatusCode, }, middleware::Next, response::{ IntoResponse, Response, }, }; use governor::{ clock::DefaultClock, Quota, RateLimiter, state::keyed::DefaultKeyedStateStore, }; use log::debug; use std::net::IpAddr; use std::num::NonZeroU32; use std::sync::Arc; pub type SimpleRateLimiter = RateLimiter, DefaultClock>; pub fn build_rate_limiting_state( requests_per_minute: NonZeroU32, request_burst_capacity: NonZeroU32, ) -> Extension>> { let quota = Quota::per_minute(requests_per_minute) .allow_burst(request_burst_capacity); let arc_limiter : Arc> = Arc::new( RateLimiter::keyed(quota) ); Extension(arc_limiter) } pub async fn rate_limit_middleware( SecureClientIp(address): SecureClientIp, Extension(arc_limiter): Extension>>, req: Request, next: Next ) -> Response { let limiter = Arc::clone(&arc_limiter); match limiter.check_key(&address) { Ok(_) => { //Little hack to prevent too many cleanups in cases of very high load if limiter.check_key(&IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)).is_ok() { let oldlen = limiter.len(); if oldlen > 100 { debug!("Doing limiter cleanup ..."); limiter.retain_recent(); limiter.shrink_to_fit(); debug!("Old limiter store size: {oldlen} New limiter store size: {}", limiter.len()); } } next.run(req).await }, Err(_) => ( StatusCode::TOO_MANY_REQUESTS, "You make too many requests! Please slow down a bit." ).into_response(), } }