From b9428d9780c8efccef6e7aac5ed68e570f9170be Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 7 Feb 2026 18:26:44 +0300 Subject: [PATCH] Antireplay on sliding window + SecureRandom --- config.toml | 1 + src/config/mod.rs | 5 ++ src/crypto/mod.rs | 2 +- src/crypto/random.rs | 5 -- src/main.rs | 12 ++++- src/protocol/tls.rs | 24 ++++++---- src/proxy/client.rs | 16 +++++-- src/proxy/handshake.rs | 14 ++++-- src/stats/mod.rs | 98 ++++++++++++++++++++++++++++++-------- src/stream/frame.rs | 6 ++- src/stream/frame_codec.rs | 41 +++++++++------- src/stream/frame_stream.rs | 23 +++++---- 12 files changed, 171 insertions(+), 76 deletions(-) diff --git a/config.toml b/config.toml index 7b1dc8a..45f370f 100644 --- a/config.toml +++ b/config.toml @@ -49,6 +49,7 @@ fake_cert_len = 2048 # username "hello" is used for example [access] replay_check_len = 65536 +replay_window_secs = 1800 ignore_time_skew = false [access.users] diff --git a/src/config/mod.rs b/src/config/mod.rs index 425aeef..b9eaf16 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -14,6 +14,7 @@ fn default_port() -> u16 { 443 } fn default_tls_domain() -> String { "www.google.com".to_string() } fn default_mask_port() -> u16 { 443 } fn default_replay_check_len() -> usize { 65536 } +fn default_replay_window_secs() -> u64 { 1800 } fn default_handshake_timeout() -> u64 { 15 } fn default_connect_timeout() -> u64 { 10 } fn default_keepalive() -> u64 { 60 } @@ -187,6 +188,9 @@ pub struct AccessConfig { #[serde(default = "default_replay_check_len")] pub replay_check_len: usize, + #[serde(default = "default_replay_window_secs")] + pub replay_window_secs: u64, + #[serde(default)] pub ignore_time_skew: bool, } @@ -201,6 +205,7 @@ impl Default for AccessConfig { user_expirations: HashMap::new(), user_data_quota: HashMap::new(), replay_check_len: default_replay_check_len(), + replay_window_secs: default_replay_window_secs(), ignore_time_skew: false, } } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 6339927..dfc2be6 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -6,4 +6,4 @@ pub mod random; pub use aes::{AesCtr, AesCbc}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; -pub use random::{SecureRandom, SECURE_RANDOM}; \ No newline at end of file +pub use random::SecureRandom; \ No newline at end of file diff --git a/src/crypto/random.rs b/src/crypto/random.rs index c179f25..19d8788 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -4,11 +4,6 @@ use rand::{Rng, RngCore, SeedableRng}; use rand::rngs::StdRng; use parking_lot::Mutex; use crate::crypto::AesCtr; -use once_cell::sync::Lazy; - -/// Global secure random instance -pub static SECURE_RANDOM: Lazy = Lazy::new(SecureRandom::new); - /// Cryptographically secure PRNG with AES-CTR pub struct SecureRandom { inner: Mutex, diff --git a/src/main.rs b/src/main.rs index a672ed0..f87bec5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ mod util; use crate::config::ProxyConfig; use crate::proxy::ClientHandler; use crate::stats::{Stats, ReplayChecker}; +use crate::crypto::SecureRandom; use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::util::ip::detect_ip; use crate::stream::BufferPool; @@ -68,10 +69,14 @@ async fn main() -> Result<(), Box> { let config = Arc::new(config); let stats = Arc::new(Stats::new()); + let rng = Arc::new(SecureRandom::new()); // Initialize global ReplayChecker // Using sharded implementation for better concurrency - let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len)); + let replay_checker = Arc::new(ReplayChecker::new( + config.access.replay_check_len, + Duration::from_secs(config.access.replay_window_secs), + )); // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); @@ -166,6 +171,7 @@ async fn main() -> Result<(), Box> { let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); tokio::spawn(async move { loop { @@ -176,6 +182,7 @@ async fn main() -> Result<(), Box> { let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); tokio::spawn(async move { if let Err(e) = ClientHandler::new( @@ -185,7 +192,8 @@ async fn main() -> Result<(), Box> { stats, upstream_manager, replay_checker, - buffer_pool + buffer_pool, + rng ).run().await { // Log only relevant errors } diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 354ee9a..68cd3dc 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -4,7 +4,7 @@ //! for domain fronting. The handshake looks like valid TLS 1.3 but //! actually carries MTProto authentication data. -use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; +use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::{ProxyError, Result}; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; @@ -315,8 +315,8 @@ pub fn validate_tls_handshake( /// /// This generates random bytes that look like a valid X25519 public key. /// Since we're not doing real TLS, the actual cryptographic properties don't matter. -pub fn gen_fake_x25519_key() -> [u8; 32] { - let bytes = SECURE_RANDOM.bytes(32); +pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { + let bytes = rng.bytes(32); bytes.try_into().unwrap() } @@ -333,8 +333,9 @@ pub fn build_server_hello( client_digest: &[u8; TLS_DIGEST_LEN], session_id: &[u8], fake_cert_len: usize, + rng: &SecureRandom, ) -> Vec { - let x25519_key = gen_fake_x25519_key(); + let x25519_key = gen_fake_x25519_key(rng); // Build ServerHello let server_hello = ServerHelloBuilder::new(session_id.to_vec()) @@ -351,7 +352,7 @@ pub fn build_server_hello( ]; // Build fake certificate (Application Data record) - let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); + let fake_cert = rng.bytes(fake_cert_len); let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.extend_from_slice(&TLS_VERSION); @@ -489,8 +490,9 @@ mod tests { #[test] fn test_gen_fake_x25519_key() { - let key1 = gen_fake_x25519_key(); - let key2 = gen_fake_x25519_key(); + let rng = SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); assert_eq!(key1.len(), 32); assert_eq!(key2.len(), 32); @@ -545,7 +547,8 @@ mod tests { let client_digest = [0x42u8; 32]; let session_id = vec![0xAA; 32]; - let response = build_server_hello(secret, &client_digest, &session_id, 2048); + let rng = SecureRandom::new(); + let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng); // Should have at least 3 records assert!(response.len() > 100); @@ -577,8 +580,9 @@ mod tests { let client_digest = [0x42u8; 32]; let session_id = vec![0xAA; 32]; - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024); + let rng = SecureRandom::new(); + let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng); // Digest position should have non-zero data let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 29ef0cd..2d85a9f 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -15,7 +15,7 @@ use crate::protocol::tls; use crate::stats::{Stats, ReplayChecker}; use crate::transport::{configure_client_socket, UpstreamManager}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; -use crate::crypto::AesCtr; +use crate::crypto::{AesCtr, SecureRandom}; // Use absolute paths to avoid confusion use crate::proxy::handshake::{ @@ -37,6 +37,7 @@ pub struct RunningClientHandler { replay_checker: Arc, upstream_manager: Arc, buffer_pool: Arc, + rng: Arc, } impl ClientHandler { @@ -49,6 +50,7 @@ impl ClientHandler { upstream_manager: Arc, replay_checker: Arc, buffer_pool: Arc, + rng: Arc, ) -> RunningClientHandler { RunningClientHandler { stream, @@ -58,6 +60,7 @@ impl ClientHandler { replay_checker, upstream_manager, buffer_pool, + rng, } } } @@ -168,6 +171,7 @@ impl RunningClientHandler { peer, &config, &replay_checker, + &self.rng, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { @@ -211,7 +215,8 @@ impl RunningClientHandler { self.upstream_manager, self.stats, self.config, - buffer_pool + buffer_pool, + self.rng ).await } @@ -272,7 +277,8 @@ impl RunningClientHandler { self.upstream_manager, self.stats, self.config, - buffer_pool + buffer_pool, + self.rng ).await } @@ -285,6 +291,7 @@ impl RunningClientHandler { stats: Arc, config: Arc, buffer_pool: Arc, + rng: Arc, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -321,6 +328,7 @@ impl RunningClientHandler { tg_stream, &success, &config, + rng.as_ref(), ).await?; debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); @@ -401,12 +409,14 @@ impl RunningClientHandler { mut stream: TcpStream, success: &HandshakeSuccess, config: &ProxyConfig, + rng: &SecureRandom, ) -> Result<(CryptoReader, CryptoWriter)> { // Generate nonce with keys for TG let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( success.proto_tag, &success.dec_key, // Client's dec key success.dec_iv, + rng, config.general.fast_mode, ); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index bb0ad1a..7e65716 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,8 +4,7 @@ use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace, info}; -use crate::crypto::{sha256, AesCtr}; -use crate::crypto::random::SECURE_RANDOM; +use crate::crypto::{sha256, AesCtr, SecureRandom}; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; @@ -42,6 +41,7 @@ pub async fn handle_tls_handshake( peer: SocketAddr, config: &ProxyConfig, replay_checker: &ReplayChecker, + rng: &SecureRandom, ) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String), R, W> where R: AsyncRead + Unpin, @@ -101,6 +101,7 @@ where &validation.digest, &validation.session_id, config.censorship.fake_cert_len, + rng, ); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -264,10 +265,11 @@ pub fn generate_tg_nonce( proto_tag: ProtoTag, client_dec_key: &[u8; 32], client_dec_iv: u128, + rng: &SecureRandom, fast_mode: bool, ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { loop { - let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); + let bytes = rng.bytes(HANDSHAKE_LEN); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } @@ -323,8 +325,9 @@ mod tests { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; + let rng = SecureRandom::new(); let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); // Check length assert_eq!(nonce.len(), HANDSHAKE_LEN); @@ -339,8 +342,9 @@ mod tests { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; + let rng = SecureRandom::new(); let (nonce, _, _, _, _) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); let encrypted = encrypt_tg_nonce(&nonce); diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 9fa495d..39002e8 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -2,13 +2,14 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use std::time::Instant; +use std::time::{Instant, Duration}; use dashmap::DashMap; use parking_lot::{RwLock, Mutex}; use lru::LruCache; use std::num::NonZeroUsize; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; +use std::collections::VecDeque; /// Thread-safe statistics #[derive(Default)] @@ -143,57 +144,112 @@ impl Stats { } } -/// Sharded Replay attack checker using LRU cache +/// Sharded Replay attack checker using LRU cache + sliding window /// Uses multiple independent LRU caches to reduce lock contention pub struct ReplayChecker { - shards: Vec, ()>>>, + shards: Vec>, shard_mask: usize, + window: Duration, +} + +struct ReplayEntry { + seen_at: Instant, +} + +struct ReplayShard { + cache: LruCache, ReplayEntry>, + queue: VecDeque<(Instant, Vec)>, +} + +impl ReplayShard { + fn new(cap: NonZeroUsize) -> Self { + Self { + cache: LruCache::new(cap), + queue: VecDeque::with_capacity(cap.get()), + } + } + + fn cleanup(&mut self, now: Instant, window: Duration) { + if window.is_zero() { + return; + } + let cutoff = now - window; + while let Some((ts, _)) = self.queue.front() { + if *ts >= cutoff { + break; + } + let (ts_old, key_old) = self.queue.pop_front().unwrap(); + if let Some(entry) = self.cache.get(&key_old) { + if entry.seen_at <= ts_old { + self.cache.pop(&key_old); + } + } + } + } } impl ReplayChecker { /// Create new replay checker with specified capacity per shard /// Total capacity = capacity * num_shards - pub fn new(total_capacity: usize) -> Self { + pub fn new(total_capacity: usize, window: Duration) -> Self { // Use 64 shards for good concurrency let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); - + let mut shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { - shards.push(Mutex::new(LruCache::new(cap))); + shards.push(Mutex::new(ReplayShard::new(cap))); } - + Self { shards, shard_mask: num_shards - 1, + window, } } - + fn get_shard(&self, key: &[u8]) -> usize { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); (hasher.finish() as usize) & self.shard_mask } - + + fn check(&self, data: &[u8]) -> bool { + let shard_idx = self.get_shard(data); + let mut shard = self.shards[shard_idx].lock(); + let now = Instant::now(); + shard.cleanup(now, self.window); + + let key = data.to_vec(); + shard.cache.get(&key).is_some() + } + + fn add(&self, data: &[u8]) { + let shard_idx = self.get_shard(data); + let mut shard = self.shards[shard_idx].lock(); + let now = Instant::now(); + shard.cleanup(now, self.window); + + let key = data.to_vec(); + shard.cache.put(key.clone(), ReplayEntry { seen_at: now }); + shard.queue.push_back((now, key)); + } + pub fn check_handshake(&self, data: &[u8]) -> bool { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().contains(&data.to_vec()) + self.check(data) } - + pub fn add_handshake(&self, data: &[u8]) { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().put(data.to_vec(), ()); + self.add(data) } - + pub fn check_tls_digest(&self, data: &[u8]) -> bool { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().contains(&data.to_vec()) + self.check(data) } - + pub fn add_tls_digest(&self, data: &[u8]) { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().put(data.to_vec(), ()); + self.add(data) } } @@ -217,7 +273,7 @@ mod tests { #[test] fn test_replay_checker_sharding() { - let checker = ReplayChecker::new(100); + let checker = ReplayChecker::new(100, Duration::from_secs(60)); let data1 = b"test1"; let data2 = b"test2"; diff --git a/src/stream/frame.rs b/src/stream/frame.rs index 42f6c74..b97d4cf 100644 --- a/src/stream/frame.rs +++ b/src/stream/frame.rs @@ -5,8 +5,10 @@ use bytes::{Bytes, BytesMut}; use std::io::Result; +use std::sync::Arc; use crate::protocol::constants::ProtoTag; +use crate::crypto::SecureRandom; // ============= Frame Types ============= @@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync { // ============= Codec Factory ============= /// Create a frame codec for the given protocol tag -pub fn create_codec(proto_tag: ProtoTag) -> Box { +pub fn create_codec(proto_tag: ProtoTag, rng: Arc) -> Box { match proto_tag { ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()), ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()), - ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()), + ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)), } } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 75f6bde..30bcc95 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -5,9 +5,11 @@ use bytes::{Bytes, BytesMut, BufMut}; use std::io::{self, Error, ErrorKind}; +use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use crate::protocol::constants::ProtoTag; +use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; // ============= Unified Codec ============= @@ -21,14 +23,17 @@ pub struct FrameCodec { proto_tag: ProtoTag, /// Maximum allowed frame size max_frame_size: usize, + /// RNG for secure padding + rng: Arc, } impl FrameCodec { /// Create a new codec for the given protocol - pub fn new(proto_tag: ProtoTag) -> Self { + pub fn new(proto_tag: ProtoTag, rng: Arc) -> Self { Self { proto_tag, max_frame_size: 16 * 1024 * 1024, // 16MB default + rng, } } @@ -64,7 +69,7 @@ impl Encoder for FrameCodec { match self.proto_tag { ProtoTag::Abridged => encode_abridged(&frame, dst), ProtoTag::Intermediate => encode_intermediate(&frame, dst), - ProtoTag::Secure => encode_secure(&frame, dst), + ProtoTag::Secure => encode_secure(&frame, dst, &self.rng), } } } @@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result io::Result<()> { - use crate::crypto::random::SECURE_RANDOM; - +fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> { let data = &frame.data; // Simple ACK: just send data @@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { // Generate padding to make length not divisible by 4 let padding_len = if data.len() % 4 == 0 { // Add 1-3 bytes to make it non-aligned - (SECURE_RANDOM.range(3) + 1) as usize + (rng.range(3) + 1) as usize } else { // Already non-aligned, can add 0-3 - SECURE_RANDOM.range(4) as usize + rng.range(4) as usize }; let total_len = data.len() + padding_len; @@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { dst.extend_from_slice(data); if padding_len > 0 { - let padding = SECURE_RANDOM.bytes(padding_len); + let padding = rng.bytes(padding_len); dst.extend_from_slice(&padding); } @@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec { /// Secure Intermediate protocol codec pub struct SecureCodec { max_frame_size: usize, + rng: Arc, } impl SecureCodec { - pub fn new() -> Self { + pub fn new(rng: Arc) -> Self { Self { max_frame_size: 16 * 1024 * 1024, + rng, } } } impl Default for SecureCodec { fn default() -> Self { - Self::new() + Self::new(Arc::new(SecureRandom::new())) } } @@ -474,7 +479,7 @@ impl Encoder for SecureCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { - encode_secure(&frame, dst) + encode_secure(&frame, dst, &self.rng) } } @@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec { fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { let before = dst.len(); - encode_secure(frame, dst)?; + encode_secure(frame, dst, &self.rng)?; Ok(dst.len() - before) } @@ -506,6 +511,8 @@ mod tests { use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; + use crate::crypto::SecureRandom; + use std::sync::Arc; #[tokio::test] async fn test_framed_abridged() { @@ -541,8 +548,8 @@ mod tests { async fn test_framed_secure() { let (client, server) = duplex(4096); - let mut writer = FramedWrite::new(client, SecureCodec::new()); - let mut reader = FramedRead::new(server, SecureCodec::new()); + let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new()))); + let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new()))); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let frame = Frame::new(original.clone()); @@ -557,8 +564,8 @@ mod tests { for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { let (client, server) = duplex(4096); - let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag)); - let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag)); + let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); + let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); // Use 4-byte aligned data for abridged compatibility let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); @@ -607,7 +614,7 @@ mod tests { #[test] fn test_frame_too_large() { - let mut codec = FrameCodec::new(ProtoTag::Intermediate) + let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new())) .with_max_frame_size(100); // Create a "frame" that claims to be very large diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 9e62c8d..fd8c1b4 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut}; use std::io::{Error, ErrorKind, Result}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use crate::protocol::constants::*; -use crate::crypto::crc32; -use crate::crypto::random::SECURE_RANDOM; +use crate::crypto::{crc32, SecureRandom}; +use std::sync::Arc; use super::traits::{FrameMeta, LayeredStream}; // ============= Abridged (Compact) Frame ============= @@ -251,11 +251,12 @@ impl LayeredStream for SecureIntermediateFrameReader { /// Writer for secure intermediate MTProto framing pub struct SecureIntermediateFrameWriter { upstream: W, + rng: Arc, } impl SecureIntermediateFrameWriter { - pub fn new(upstream: W) -> Self { - Self { upstream } + pub fn new(upstream: W, rng: Arc) -> Self { + Self { upstream, rng } } } @@ -267,8 +268,8 @@ impl SecureIntermediateFrameWriter { } // Add random padding (0-3 bytes) - let padding_len = SECURE_RANDOM.range(4); - let padding = SECURE_RANDOM.bytes(padding_len); + let padding_len = self.rng.range(4); + let padding = self.rng.bytes(padding_len); let total_len = data.len() + padding_len; let len_bytes = (total_len as u32).to_le_bytes(); @@ -454,11 +455,11 @@ pub enum FrameWriterKind { } impl FrameWriterKind { - pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { + pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc) -> Self { match proto_tag { ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), - ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), + ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)), } } @@ -483,6 +484,8 @@ impl FrameWriterKind { mod tests { use super::*; use tokio::io::duplex; + use std::sync::Arc; + use crate::crypto::SecureRandom; #[tokio::test] async fn test_abridged_roundtrip() { @@ -539,7 +542,7 @@ mod tests { async fn test_secure_intermediate_padding() { let (client, server) = duplex(1024); - let mut writer = SecureIntermediateFrameWriter::new(client); + let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new())); let mut reader = SecureIntermediateFrameReader::new(server); let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; @@ -572,7 +575,7 @@ mod tests { async fn test_frame_reader_kind() { let (client, server) = duplex(1024); - let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); + let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new())); let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); let data = vec![1u8, 2, 3, 4];