Antireplay Improvements + DC Ping
- Fix: LruCache::get type ambiguity in stats/mod.rs - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already &[u8], resolved via Box<[u8]>: Borrow<[u8]>) - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit &[u8] instead of &Box<[u8]>) - Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval): - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT - `connect()` – measures RTT of each real connection and updates EMA - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error - `run_health_checks(prefer_ipv6)` – accepts IPv6 preference parameter, rotates DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced to 30s from 60s, failed checks now mark upstream as unhealthy after 3 consecutive fails - `DcPingResult` / `StartupPingResult` – public structures for display - DC Ping at startup: calls `upstream_manager.ping_all_dcs()` before accept loop, outputs table via `println!` (always visible) - Health checks with `prefer_ipv6`: `run_health_checks(prefer_ipv6)` receives the parameter - Exported `StartupPingResult` and `DcPingResult` - Summary: Startup DC ping with RTT, rotational health-check with EMA latency tracking, 30-second interval, correct unhealthy marking after 3 fails. Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
333
src/stats/mod.rs
333
src/stats/mod.rs
@@ -1,32 +1,28 @@
|
||||
//! Statistics
|
||||
//! Statistics and replay protection
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Instant, Duration};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::{RwLock, Mutex};
|
||||
use parking_lot::Mutex;
|
||||
use lru::LruCache;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::VecDeque;
|
||||
use tracing::debug;
|
||||
|
||||
// ============= Stats =============
|
||||
|
||||
/// Thread-safe statistics
|
||||
#[derive(Default)]
|
||||
pub struct Stats {
|
||||
// Global counters
|
||||
connects_all: AtomicU64,
|
||||
connects_bad: AtomicU64,
|
||||
handshake_timeouts: AtomicU64,
|
||||
|
||||
// Per-user stats
|
||||
user_stats: DashMap<String, UserStats>,
|
||||
|
||||
// Start time
|
||||
start_time: RwLock<Option<Instant>>,
|
||||
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||
}
|
||||
|
||||
/// Per-user statistics
|
||||
#[derive(Default)]
|
||||
pub struct UserStats {
|
||||
pub connects: AtomicU64,
|
||||
@@ -44,42 +40,20 @@ impl Stats {
|
||||
stats
|
||||
}
|
||||
|
||||
// Global stats
|
||||
pub fn increment_connects_all(&self) {
|
||||
self.connects_all.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
|
||||
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
|
||||
|
||||
pub fn increment_connects_bad(&self) {
|
||||
self.connects_bad.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_handshake_timeouts(&self) {
|
||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_connects_all(&self) -> u64 {
|
||||
self.connects_all.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_connects_bad(&self) -> u64 {
|
||||
self.connects_bad.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
// User stats
|
||||
pub fn increment_user_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_curr_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.curr_connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||
@@ -89,47 +63,33 @@ impl Stats {
|
||||
}
|
||||
|
||||
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
self.user_stats.get(user)
|
||||
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_from_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_to_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_from_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_to_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
self.user_stats.get(user)
|
||||
.map(|s| {
|
||||
s.octets_from_client.load(Ordering::Relaxed) +
|
||||
s.octets_to_client.load(Ordering::Relaxed)
|
||||
@@ -144,21 +104,27 @@ impl Stats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sharded Replay attack checker using LRU cache + sliding window
|
||||
/// Uses multiple independent LRU caches to reduce lock contention
|
||||
// ============= Replay Checker =============
|
||||
|
||||
pub struct ReplayChecker {
|
||||
shards: Vec<Mutex<ReplayShard>>,
|
||||
shard_mask: usize,
|
||||
window: Duration,
|
||||
checks: AtomicU64,
|
||||
hits: AtomicU64,
|
||||
additions: AtomicU64,
|
||||
cleanups: AtomicU64,
|
||||
}
|
||||
|
||||
struct ReplayEntry {
|
||||
seen_at: Instant,
|
||||
seq: u64,
|
||||
}
|
||||
|
||||
struct ReplayShard {
|
||||
cache: LruCache<Vec<u8>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Vec<u8>)>,
|
||||
cache: LruCache<Box<[u8]>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
|
||||
seq_counter: u64,
|
||||
}
|
||||
|
||||
impl ReplayShard {
|
||||
@@ -166,33 +132,60 @@ impl ReplayShard {
|
||||
Self {
|
||||
cache: LruCache::new(cap),
|
||||
queue: VecDeque::with_capacity(cap.get()),
|
||||
seq_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn next_seq(&mut self) -> u64 {
|
||||
self.seq_counter += 1;
|
||||
self.seq_counter
|
||||
}
|
||||
|
||||
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
return;
|
||||
}
|
||||
let cutoff = now - window;
|
||||
while let Some((ts, _)) = self.queue.front() {
|
||||
let cutoff = now.checked_sub(window).unwrap_or(now);
|
||||
|
||||
while let Some((ts, _, _)) = self.queue.front() {
|
||||
if *ts >= cutoff {
|
||||
break;
|
||||
}
|
||||
let (ts_old, key_old) = self.queue.pop_front().unwrap();
|
||||
if let Some(entry) = self.cache.get(&key_old) {
|
||||
if entry.seen_at <= ts_old {
|
||||
self.cache.pop(&key_old);
|
||||
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
|
||||
|
||||
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
|
||||
// between Borrow<[u8]> and Borrow<Box<[u8]>>
|
||||
if let Some(entry) = self.cache.peek(key.as_ref()) {
|
||||
if entry.seq == queue_seq {
|
||||
self.cache.pop(key.as_ref());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
|
||||
self.cleanup(now, window);
|
||||
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
|
||||
self.cache.get(key).is_some()
|
||||
}
|
||||
|
||||
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
|
||||
self.cleanup(now, window);
|
||||
|
||||
let seq = self.next_seq();
|
||||
let boxed_key: Box<[u8]> = key.into();
|
||||
|
||||
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
|
||||
self.queue.push_back((now, boxed_key, seq));
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.cache.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplayChecker {
|
||||
/// Create new replay checker with specified capacity per shard
|
||||
/// Total capacity = capacity * num_shards
|
||||
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||
// Use 64 shards for good concurrency
|
||||
let num_shards = 64;
|
||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||
@@ -206,50 +199,114 @@ impl ReplayChecker {
|
||||
shards,
|
||||
shard_mask: num_shards - 1,
|
||||
window,
|
||||
checks: AtomicU64::new(0),
|
||||
hits: AtomicU64::new(0),
|
||||
additions: AtomicU64::new(0),
|
||||
cleanups: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_shard(&self, key: &[u8]) -> usize {
|
||||
fn get_shard_idx(&self, key: &[u8]) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
(hasher.finish() as usize) & self.shard_mask
|
||||
}
|
||||
|
||||
fn check(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.get(&key).is_some()
|
||||
self.checks.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
let mut shard = self.shards[idx].lock();
|
||||
let found = shard.check(data, Instant::now(), self.window);
|
||||
if found {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
found
|
||||
}
|
||||
|
||||
fn add(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
|
||||
shard.queue.push_back((now, key));
|
||||
self.additions.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
let mut shard = self.shards[idx].lock();
|
||||
shard.add(data, Instant::now(), self.window);
|
||||
}
|
||||
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
||||
self.check(data)
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
|
||||
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
|
||||
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
|
||||
|
||||
pub fn stats(&self) -> ReplayStats {
|
||||
let mut total_entries = 0;
|
||||
let mut total_queue_len = 0;
|
||||
for shard in &self.shards {
|
||||
let s = shard.lock();
|
||||
total_entries += s.cache.len();
|
||||
total_queue_len += s.queue.len();
|
||||
}
|
||||
|
||||
ReplayStats {
|
||||
total_entries,
|
||||
total_queue_len,
|
||||
total_checks: self.checks.load(Ordering::Relaxed),
|
||||
total_hits: self.hits.load(Ordering::Relaxed),
|
||||
total_additions: self.additions.load(Ordering::Relaxed),
|
||||
total_cleanups: self.cleanups.load(Ordering::Relaxed),
|
||||
num_shards: self.shards.len(),
|
||||
window_secs: self.window.as_secs(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_handshake(&self, data: &[u8]) {
|
||||
self.add(data)
|
||||
|
||||
pub async fn run_periodic_cleanup(&self) {
|
||||
let interval = if self.window.as_secs() > 60 {
|
||||
Duration::from_secs(30)
|
||||
} else {
|
||||
Duration::from_secs(self.window.as_secs().max(1) / 2)
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
let now = Instant::now();
|
||||
let mut cleaned = 0usize;
|
||||
|
||||
for shard_mutex in &self.shards {
|
||||
let mut shard = shard_mutex.lock();
|
||||
let before = shard.len();
|
||||
shard.cleanup(now, self.window);
|
||||
let after = shard.len();
|
||||
cleaned += before.saturating_sub(after);
|
||||
}
|
||||
|
||||
self.cleanups.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if cleaned > 0 {
|
||||
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
self.check(data)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplayStats {
|
||||
pub total_entries: usize,
|
||||
pub total_queue_len: usize,
|
||||
pub total_checks: u64,
|
||||
pub total_hits: u64,
|
||||
pub total_additions: u64,
|
||||
pub total_cleanups: u64,
|
||||
pub num_shards: usize,
|
||||
pub window_secs: u64,
|
||||
}
|
||||
|
||||
impl ReplayStats {
|
||||
pub fn hit_rate(&self) -> f64 {
|
||||
if self.total_checks == 0 { 0.0 }
|
||||
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
|
||||
}
|
||||
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
self.add(data)
|
||||
|
||||
pub fn ghost_ratio(&self) -> f64 {
|
||||
if self.total_entries == 0 { 0.0 }
|
||||
else { self.total_queue_len as f64 / self.total_entries as f64 }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,28 +317,60 @@ mod tests {
|
||||
#[test]
|
||||
fn test_stats_shared_counters() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let stats1 = Arc::clone(&stats);
|
||||
let stats2 = Arc::clone(&stats);
|
||||
|
||||
stats1.increment_connects_all();
|
||||
stats2.increment_connects_all();
|
||||
stats1.increment_connects_all();
|
||||
|
||||
stats.increment_connects_all();
|
||||
stats.increment_connects_all();
|
||||
stats.increment_connects_all();
|
||||
assert_eq!(stats.get_connects_all(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_sharding() {
|
||||
fn test_replay_checker_basic() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
let data1 = b"test1";
|
||||
let data2 = b"test2";
|
||||
|
||||
checker.add_handshake(data1);
|
||||
assert!(checker.check_handshake(data1));
|
||||
assert!(!checker.check_handshake(data2));
|
||||
|
||||
checker.add_handshake(data2);
|
||||
assert!(checker.check_handshake(data2));
|
||||
assert!(!checker.check_handshake(b"test1"));
|
||||
checker.add_handshake(b"test1");
|
||||
assert!(checker.check_handshake(b"test1"));
|
||||
assert!(!checker.check_handshake(b"test2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_duplicate_add() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
checker.add_handshake(b"dup");
|
||||
checker.add_handshake(b"dup");
|
||||
assert!(checker.check_handshake(b"dup"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_expiration() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_millis(50));
|
||||
checker.add_handshake(b"expire");
|
||||
assert!(checker.check_handshake(b"expire"));
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(!checker.check_handshake(b"expire"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_stats() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
checker.add_handshake(b"k1");
|
||||
checker.add_handshake(b"k2");
|
||||
checker.check_handshake(b"k1");
|
||||
checker.check_handshake(b"k3");
|
||||
let stats = checker.stats();
|
||||
assert_eq!(stats.total_additions, 2);
|
||||
assert_eq!(stats.total_checks, 2);
|
||||
assert_eq!(stats.total_hits, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_many_keys() {
|
||||
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
|
||||
for i in 0..500u32 {
|
||||
checker.add(&i.to_le_bytes());
|
||||
}
|
||||
for i in 0..500u32 {
|
||||
assert!(checker.check(&i.to_le_bytes()));
|
||||
}
|
||||
assert_eq!(checker.stats().total_entries, 500);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user