diff --git a/Cargo.lock b/Cargo.lock index 7916e28..ce8ccea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -372,6 +372,7 @@ dependencies = [ "idna 0.3.0", "lazy_static", "maxminddb", + "parking_lot", "regex", "serde", "tera", diff --git a/Cargo.toml b/Cargo.toml index 73d0296..ad49e85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ clap = { version = "4", features = ["derive"] } governor = "0.5" idna = "0.3" lazy_static = "1.4.0" +parking_lot = "0.12" regex = "1.7" serde = { version = "1", features = ["derive"] } tokio = { version = "1", features = ["full"] } diff --git a/src/geoip.rs b/src/geoip.rs index e494828..e0c45e6 100644 --- a/src/geoip.rs +++ b/src/geoip.rs @@ -6,6 +6,8 @@ use maxminddb; use maxminddb::geoip2; +use parking_lot::RwLock; + use std::collections::BTreeMap; use std::net::IpAddr; use std::path::Path; @@ -47,8 +49,9 @@ pub struct AsnResult { } pub struct MMDBCarrier { - pub mmdb: Option>>, + pub mmdb: RwLock>>>, pub name: String, + pub path: Option, } pub trait QueryLocation { @@ -122,7 +125,8 @@ pub fn geoip2_subdivision_to_named_location(item: geoip2::city::Subdivision, lan impl QueryAsn for MMDBCarrier { fn query_asn_for_ip(&self, address: &IpAddr) -> Option { - match &self.mmdb { + let mmdb = self.mmdb.read(); + match &*mmdb { Some(mmdb) => { match mmdb.lookup::(*address) { Ok(res) => { @@ -144,7 +148,8 @@ impl QueryAsn for MMDBCarrier { impl QueryLocation for MMDBCarrier { fn query_location_for_ip(&self, address: &IpAddr, languages: &Vec<&String>) -> Option { - match &self.mmdb { + let mmdb = self.mmdb.read(); + match &*mmdb { Some(mmdb) => { match mmdb.lookup::(*address) { Ok(res) => { @@ -210,22 +215,38 @@ impl QueryLocation for MMDBCarrier { } impl MMDBCarrier { - pub fn load_database_from_path(&mut self, path: &Path) -> Result<(),maxminddb::MaxMindDBError> { + pub fn new(name: String, path: Option) -> MMDBCarrier { + MMDBCarrier { + mmdb: RwLock::new(None), + name: name, + path: path, + } + } + + pub fn reload_database(&self) -> Result<(),maxminddb::MaxMindDBError> { + match &self.path { + Some(path) => self.load_database_from_path(Path::new(&path)), + None => Ok(()), + } + } + + pub fn load_database_from_path(&self, path: &Path) -> Result<(),maxminddb::MaxMindDBError> { + let mut mmdb = self.mmdb.write(); println!("Loading {} from '{}' ...", &self.name, path.display()); match maxminddb::Reader::open_readfile(path) { Ok(reader) => { - let wording = if self.mmdb.is_some() { + let wording = if mmdb.is_some() { "Replaced old" } else { "Loaded new" }; - self.mmdb = Some(reader); + *mmdb = Some(reader); println!("{} {} with new one.", wording, &self.name); Ok(()) }, Err(e) => { println!("Error while reading {}: {}", &self.name, &e); - if self.mmdb.is_some() { + if mmdb.is_some() { println!("Not replacing old database."); } Err(e) diff --git a/src/main.rs b/src/main.rs index 77d79de..cd00c06 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,7 +30,6 @@ use trust_dns_resolver::{ use std::fs; use std::net::IpAddr; use std::sync::Arc; -use std::path::Path; mod config; mod geoip; @@ -205,23 +204,19 @@ async fn main() { // Initalize GeoIP Database - let mut asn_db = geoip::MMDBCarrier { - mmdb: None, - name: "GeoIP ASN Database".to_string(), - }; - match &config.geoip.asn_database { - Some(path) => { asn_db.load_database_from_path(Path::new(&path)).ok(); }, - None => {}, - } + let asn_db = geoip::MMDBCarrier::new( + "GeoIP ASN Database".to_string(), + config.geoip.asn_database.clone() + ); - let mut location_db = geoip::MMDBCarrier { - mmdb: None, - name: "GeoIP Location Database".to_string(), - }; - match &config.geoip.location_database { - Some(path) => { location_db.load_database_from_path(Path::new(&path)).ok(); }, - None => {}, - } + asn_db.reload_database().ok(); + + let location_db = geoip::MMDBCarrier::new( + "GeoIP Location Database".to_string(), + config.geoip.location_database.clone() + ); + + location_db.reload_database().ok(); // Initalize DNS resolver with os defaults println!("Initalizing dns resolver ...");