Antireplay on sliding window + SecureRandom

This commit is contained in:
Alexey
2026-02-07 18:26:44 +03:00
parent 5876f0c4d5
commit b9428d9780
12 changed files with 171 additions and 76 deletions

View File

@@ -2,13 +2,14 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Instant, Duration};
use dashmap::DashMap;
use parking_lot::{RwLock, Mutex};
use lru::LruCache;
use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
/// Thread-safe statistics
#[derive(Default)]
@@ -143,57 +144,112 @@ impl Stats {
}
}
/// Sharded Replay attack checker using LRU cache
/// Sharded Replay attack checker using LRU cache + sliding window
/// Uses multiple independent LRU caches to reduce lock contention
pub struct ReplayChecker {
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
}
struct ReplayEntry {
seen_at: Instant,
}
struct ReplayShard {
cache: LruCache<Vec<u8>, ReplayEntry>,
queue: VecDeque<(Instant, Vec<u8>)>,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
}
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now - window;
while let Some((ts, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (ts_old, key_old) = self.queue.pop_front().unwrap();
if let Some(entry) = self.cache.get(&key_old) {
if entry.seen_at <= ts_old {
self.cache.pop(&key_old);
}
}
}
}
}
impl ReplayChecker {
/// Create new replay checker with specified capacity per shard
/// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize) -> Self {
pub fn new(total_capacity: usize, window: Duration) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap)));
shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
shard_mask: num_shards - 1,
window,
}
}
fn get_shard(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
let now = Instant::now();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.get(&key).is_some()
}
fn add(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
let now = Instant::now();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
shard.queue.push_back((now, key));
}
pub fn check_handshake(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
self.check(data)
}
pub fn add_handshake(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
self.add(data)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
self.check(data)
}
pub fn add_tls_digest(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
self.add(data)
}
}
@@ -217,7 +273,7 @@ mod tests {
#[test]
fn test_replay_checker_sharding() {
let checker = ReplayChecker::new(100);
let checker = ReplayChecker::new(100, Duration::from_secs(60));
let data1 = b"test1";
let data2 = b"test2";