Antireplay on sliding window + SecureRandom
This commit is contained in:
@@ -2,13 +2,14 @@
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use std::time::{Instant, Duration};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::{RwLock, Mutex};
|
||||
use lru::LruCache;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Thread-safe statistics
|
||||
#[derive(Default)]
|
||||
@@ -143,57 +144,112 @@ impl Stats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sharded Replay attack checker using LRU cache
|
||||
/// Sharded Replay attack checker using LRU cache + sliding window
|
||||
/// Uses multiple independent LRU caches to reduce lock contention
|
||||
pub struct ReplayChecker {
|
||||
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
|
||||
shards: Vec<Mutex<ReplayShard>>,
|
||||
shard_mask: usize,
|
||||
window: Duration,
|
||||
}
|
||||
|
||||
struct ReplayEntry {
|
||||
seen_at: Instant,
|
||||
}
|
||||
|
||||
struct ReplayShard {
|
||||
cache: LruCache<Vec<u8>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Vec<u8>)>,
|
||||
}
|
||||
|
||||
impl ReplayShard {
|
||||
fn new(cap: NonZeroUsize) -> Self {
|
||||
Self {
|
||||
cache: LruCache::new(cap),
|
||||
queue: VecDeque::with_capacity(cap.get()),
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
return;
|
||||
}
|
||||
let cutoff = now - window;
|
||||
while let Some((ts, _)) = self.queue.front() {
|
||||
if *ts >= cutoff {
|
||||
break;
|
||||
}
|
||||
let (ts_old, key_old) = self.queue.pop_front().unwrap();
|
||||
if let Some(entry) = self.cache.get(&key_old) {
|
||||
if entry.seen_at <= ts_old {
|
||||
self.cache.pop(&key_old);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplayChecker {
|
||||
/// Create new replay checker with specified capacity per shard
|
||||
/// Total capacity = capacity * num_shards
|
||||
pub fn new(total_capacity: usize) -> Self {
|
||||
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||
// Use 64 shards for good concurrency
|
||||
let num_shards = 64;
|
||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||
|
||||
|
||||
let mut shards = Vec::with_capacity(num_shards);
|
||||
for _ in 0..num_shards {
|
||||
shards.push(Mutex::new(LruCache::new(cap)));
|
||||
shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||
}
|
||||
|
||||
|
||||
Self {
|
||||
shards,
|
||||
shard_mask: num_shards - 1,
|
||||
window,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn get_shard(&self, key: &[u8]) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
(hasher.finish() as usize) & self.shard_mask
|
||||
}
|
||||
|
||||
|
||||
fn check(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.get(&key).is_some()
|
||||
}
|
||||
|
||||
fn add(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
|
||||
shard.queue.push_back((now, key));
|
||||
}
|
||||
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
||||
self.check(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn add_handshake(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
||||
self.add(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
||||
self.check(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
||||
self.add(data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,7 +273,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_sharding() {
|
||||
let checker = ReplayChecker::new(100);
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
let data1 = b"test1";
|
||||
let data2 = b"test2";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user