Antireplay on sliding window + SecureRandom
This commit is contained in:
@@ -14,6 +14,7 @@ fn default_port() -> u16 { 443 }
|
||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||
fn default_mask_port() -> u16 { 443 }
|
||||
fn default_replay_check_len() -> usize { 65536 }
|
||||
fn default_replay_window_secs() -> u64 { 1800 }
|
||||
fn default_handshake_timeout() -> u64 { 15 }
|
||||
fn default_connect_timeout() -> u64 { 10 }
|
||||
fn default_keepalive() -> u64 { 60 }
|
||||
@@ -187,6 +188,9 @@ pub struct AccessConfig {
|
||||
#[serde(default = "default_replay_check_len")]
|
||||
pub replay_check_len: usize,
|
||||
|
||||
#[serde(default = "default_replay_window_secs")]
|
||||
pub replay_window_secs: u64,
|
||||
|
||||
#[serde(default)]
|
||||
pub ignore_time_skew: bool,
|
||||
}
|
||||
@@ -201,6 +205,7 @@ impl Default for AccessConfig {
|
||||
user_expirations: HashMap::new(),
|
||||
user_data_quota: HashMap::new(),
|
||||
replay_check_len: default_replay_check_len(),
|
||||
replay_window_secs: default_replay_window_secs(),
|
||||
ignore_time_skew: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,4 +6,4 @@ pub mod random;
|
||||
|
||||
pub use aes::{AesCtr, AesCbc};
|
||||
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
|
||||
pub use random::{SecureRandom, SECURE_RANDOM};
|
||||
pub use random::SecureRandom;
|
||||
@@ -4,11 +4,6 @@ use rand::{Rng, RngCore, SeedableRng};
|
||||
use rand::rngs::StdRng;
|
||||
use parking_lot::Mutex;
|
||||
use crate::crypto::AesCtr;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// Global secure random instance
|
||||
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
|
||||
|
||||
/// Cryptographically secure PRNG with AES-CTR
|
||||
pub struct SecureRandom {
|
||||
inner: Mutex<SecureRandomInner>,
|
||||
|
||||
12
src/main.rs
12
src/main.rs
@@ -21,6 +21,7 @@ mod util;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::proxy::ClientHandler;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||
use crate::util::ip::detect_ip;
|
||||
use crate::stream::BufferPool;
|
||||
@@ -68,10 +69,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
|
||||
// Initialize global ReplayChecker
|
||||
// Using sharded implementation for better concurrency
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(
|
||||
config.access.replay_check_len,
|
||||
Duration::from_secs(config.access.replay_window_secs),
|
||||
));
|
||||
|
||||
// Initialize Upstream Manager
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||
@@ -166,6 +171,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -176,6 +182,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = ClientHandler::new(
|
||||
@@ -185,7 +192,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool
|
||||
buffer_pool,
|
||||
rng
|
||||
).run().await {
|
||||
// Log only relevant errors
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! for domain fronting. The handshake looks like valid TLS 1.3 but
|
||||
//! actually carries MTProto authentication data.
|
||||
|
||||
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
|
||||
use crate::crypto::{sha256_hmac, SecureRandom};
|
||||
use crate::error::{ProxyError, Result};
|
||||
use super::constants::*;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
@@ -315,8 +315,8 @@ pub fn validate_tls_handshake(
|
||||
///
|
||||
/// This generates random bytes that look like a valid X25519 public key.
|
||||
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
|
||||
pub fn gen_fake_x25519_key() -> [u8; 32] {
|
||||
let bytes = SECURE_RANDOM.bytes(32);
|
||||
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
||||
let bytes = rng.bytes(32);
|
||||
bytes.try_into().unwrap()
|
||||
}
|
||||
|
||||
@@ -333,8 +333,9 @@ pub fn build_server_hello(
|
||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||
session_id: &[u8],
|
||||
fake_cert_len: usize,
|
||||
rng: &SecureRandom,
|
||||
) -> Vec<u8> {
|
||||
let x25519_key = gen_fake_x25519_key();
|
||||
let x25519_key = gen_fake_x25519_key(rng);
|
||||
|
||||
// Build ServerHello
|
||||
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||
@@ -351,7 +352,7 @@ pub fn build_server_hello(
|
||||
];
|
||||
|
||||
// Build fake certificate (Application Data record)
|
||||
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
|
||||
let fake_cert = rng.bytes(fake_cert_len);
|
||||
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||
@@ -489,8 +490,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_gen_fake_x25519_key() {
|
||||
let key1 = gen_fake_x25519_key();
|
||||
let key2 = gen_fake_x25519_key();
|
||||
let rng = SecureRandom::new();
|
||||
let key1 = gen_fake_x25519_key(&rng);
|
||||
let key2 = gen_fake_x25519_key(&rng);
|
||||
|
||||
assert_eq!(key1.len(), 32);
|
||||
assert_eq!(key2.len(), 32);
|
||||
@@ -545,7 +547,8 @@ mod tests {
|
||||
let client_digest = [0x42u8; 32];
|
||||
let session_id = vec![0xAA; 32];
|
||||
|
||||
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
|
||||
let rng = SecureRandom::new();
|
||||
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
|
||||
|
||||
// Should have at least 3 records
|
||||
assert!(response.len() > 100);
|
||||
@@ -577,8 +580,9 @@ mod tests {
|
||||
let client_digest = [0x42u8; 32];
|
||||
let session_id = vec![0xAA; 32];
|
||||
|
||||
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
||||
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
||||
let rng = SecureRandom::new();
|
||||
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||
|
||||
// Digest position should have non-zero data
|
||||
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::protocol::tls;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{configure_client_socket, UpstreamManager};
|
||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
|
||||
// Use absolute paths to avoid confusion
|
||||
use crate::proxy::handshake::{
|
||||
@@ -37,6 +37,7 @@ pub struct RunningClientHandler {
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
@@ -49,6 +50,7 @@ impl ClientHandler {
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
) -> RunningClientHandler {
|
||||
RunningClientHandler {
|
||||
stream,
|
||||
@@ -58,6 +60,7 @@ impl ClientHandler {
|
||||
replay_checker,
|
||||
upstream_manager,
|
||||
buffer_pool,
|
||||
rng,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,6 +171,7 @@ impl RunningClientHandler {
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&self.rng,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient { reader, writer } => {
|
||||
@@ -211,7 +215,8 @@ impl RunningClientHandler {
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config,
|
||||
buffer_pool
|
||||
buffer_pool,
|
||||
self.rng
|
||||
).await
|
||||
}
|
||||
|
||||
@@ -272,7 +277,8 @@ impl RunningClientHandler {
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config,
|
||||
buffer_pool
|
||||
buffer_pool,
|
||||
self.rng
|
||||
).await
|
||||
}
|
||||
|
||||
@@ -285,6 +291,7 @@ impl RunningClientHandler {
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -321,6 +328,7 @@ impl RunningClientHandler {
|
||||
tg_stream,
|
||||
&success,
|
||||
&config,
|
||||
rng.as_ref(),
|
||||
).await?;
|
||||
|
||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
||||
@@ -401,12 +409,14 @@ impl RunningClientHandler {
|
||||
mut stream: TcpStream,
|
||||
success: &HandshakeSuccess,
|
||||
config: &ProxyConfig,
|
||||
rng: &SecureRandom,
|
||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
||||
// Generate nonce with keys for TG
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||
success.proto_tag,
|
||||
&success.dec_key, // Client's dec key
|
||||
success.dec_iv,
|
||||
rng,
|
||||
config.general.fast_mode,
|
||||
);
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace, info};
|
||||
|
||||
use crate::crypto::{sha256, AesCtr};
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
use crate::crypto::{sha256, AesCtr, SecureRandom};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::tls;
|
||||
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
||||
@@ -42,6 +41,7 @@ pub async fn handle_tls_handshake<R, W>(
|
||||
peer: SocketAddr,
|
||||
config: &ProxyConfig,
|
||||
replay_checker: &ReplayChecker,
|
||||
rng: &SecureRandom,
|
||||
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
@@ -101,6 +101,7 @@ where
|
||||
&validation.digest,
|
||||
&validation.session_id,
|
||||
config.censorship.fake_cert_len,
|
||||
rng,
|
||||
);
|
||||
|
||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||
@@ -264,10 +265,11 @@ pub fn generate_tg_nonce(
|
||||
proto_tag: ProtoTag,
|
||||
client_dec_key: &[u8; 32],
|
||||
client_dec_iv: u128,
|
||||
rng: &SecureRandom,
|
||||
fast_mode: bool,
|
||||
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
||||
loop {
|
||||
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
|
||||
let bytes = rng.bytes(HANDSHAKE_LEN);
|
||||
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
||||
|
||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
||||
@@ -323,8 +325,9 @@ mod tests {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
|
||||
let rng = SecureRandom::new();
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
|
||||
|
||||
// Check length
|
||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
||||
@@ -339,8 +342,9 @@ mod tests {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
|
||||
let rng = SecureRandom::new();
|
||||
let (nonce, _, _, _, _) =
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
|
||||
|
||||
let encrypted = encrypt_tg_nonce(&nonce);
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use std::time::{Instant, Duration};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::{RwLock, Mutex};
|
||||
use lru::LruCache;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Thread-safe statistics
|
||||
#[derive(Default)]
|
||||
@@ -143,57 +144,112 @@ impl Stats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sharded Replay attack checker using LRU cache
|
||||
/// Sharded Replay attack checker using LRU cache + sliding window
|
||||
/// Uses multiple independent LRU caches to reduce lock contention
|
||||
pub struct ReplayChecker {
|
||||
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
|
||||
shards: Vec<Mutex<ReplayShard>>,
|
||||
shard_mask: usize,
|
||||
window: Duration,
|
||||
}
|
||||
|
||||
struct ReplayEntry {
|
||||
seen_at: Instant,
|
||||
}
|
||||
|
||||
struct ReplayShard {
|
||||
cache: LruCache<Vec<u8>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Vec<u8>)>,
|
||||
}
|
||||
|
||||
impl ReplayShard {
|
||||
fn new(cap: NonZeroUsize) -> Self {
|
||||
Self {
|
||||
cache: LruCache::new(cap),
|
||||
queue: VecDeque::with_capacity(cap.get()),
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
return;
|
||||
}
|
||||
let cutoff = now - window;
|
||||
while let Some((ts, _)) = self.queue.front() {
|
||||
if *ts >= cutoff {
|
||||
break;
|
||||
}
|
||||
let (ts_old, key_old) = self.queue.pop_front().unwrap();
|
||||
if let Some(entry) = self.cache.get(&key_old) {
|
||||
if entry.seen_at <= ts_old {
|
||||
self.cache.pop(&key_old);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplayChecker {
|
||||
/// Create new replay checker with specified capacity per shard
|
||||
/// Total capacity = capacity * num_shards
|
||||
pub fn new(total_capacity: usize) -> Self {
|
||||
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||
// Use 64 shards for good concurrency
|
||||
let num_shards = 64;
|
||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||
|
||||
|
||||
let mut shards = Vec::with_capacity(num_shards);
|
||||
for _ in 0..num_shards {
|
||||
shards.push(Mutex::new(LruCache::new(cap)));
|
||||
shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||
}
|
||||
|
||||
|
||||
Self {
|
||||
shards,
|
||||
shard_mask: num_shards - 1,
|
||||
window,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn get_shard(&self, key: &[u8]) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
(hasher.finish() as usize) & self.shard_mask
|
||||
}
|
||||
|
||||
|
||||
fn check(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.get(&key).is_some()
|
||||
}
|
||||
|
||||
fn add(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
|
||||
shard.queue.push_back((now, key));
|
||||
}
|
||||
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
||||
self.check(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn add_handshake(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
||||
self.add(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
||||
self.check(data)
|
||||
}
|
||||
|
||||
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
||||
self.add(data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,7 +273,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_sharding() {
|
||||
let checker = ReplayChecker::new(100);
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
let data1 = b"test1";
|
||||
let data2 = b"test2";
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
// ============= Frame Types =============
|
||||
|
||||
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
|
||||
// ============= Codec Factory =============
|
||||
|
||||
/// Create a frame codec for the given protocol tag
|
||||
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
|
||||
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
|
||||
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
|
||||
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
|
||||
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use std::io::{self, Error, ErrorKind};
|
||||
use std::sync::Arc;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use crate::crypto::SecureRandom;
|
||||
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
||||
|
||||
// ============= Unified Codec =============
|
||||
@@ -21,14 +23,17 @@ pub struct FrameCodec {
|
||||
proto_tag: ProtoTag,
|
||||
/// Maximum allowed frame size
|
||||
max_frame_size: usize,
|
||||
/// RNG for secure padding
|
||||
rng: Arc<SecureRandom>,
|
||||
}
|
||||
|
||||
impl FrameCodec {
|
||||
/// Create a new codec for the given protocol
|
||||
pub fn new(proto_tag: ProtoTag) -> Self {
|
||||
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||
Self {
|
||||
proto_tag,
|
||||
max_frame_size: 16 * 1024 * 1024, // 16MB default
|
||||
rng,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
|
||||
match self.proto_tag {
|
||||
ProtoTag::Abridged => encode_abridged(&frame, dst),
|
||||
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
|
||||
ProtoTag::Secure => encode_secure(&frame, dst),
|
||||
ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
|
||||
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
// Simple ACK: just send data
|
||||
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
// Generate padding to make length not divisible by 4
|
||||
let padding_len = if data.len() % 4 == 0 {
|
||||
// Add 1-3 bytes to make it non-aligned
|
||||
(SECURE_RANDOM.range(3) + 1) as usize
|
||||
(rng.range(3) + 1) as usize
|
||||
} else {
|
||||
// Already non-aligned, can add 0-3
|
||||
SECURE_RANDOM.range(4) as usize
|
||||
rng.range(4) as usize
|
||||
};
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
if padding_len > 0 {
|
||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
||||
let padding = rng.bytes(padding_len);
|
||||
dst.extend_from_slice(&padding);
|
||||
}
|
||||
|
||||
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
|
||||
/// Secure Intermediate protocol codec
|
||||
pub struct SecureCodec {
|
||||
max_frame_size: usize,
|
||||
rng: Arc<SecureRandom>,
|
||||
}
|
||||
|
||||
impl SecureCodec {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(rng: Arc<SecureRandom>) -> Self {
|
||||
Self {
|
||||
max_frame_size: 16 * 1024 * 1024,
|
||||
rng,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecureCodec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
Self::new(Arc::new(SecureRandom::new()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_secure(&frame, dst)
|
||||
encode_secure(&frame, dst, &self.rng)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_secure(frame, dst)?;
|
||||
encode_secure(frame, dst, &self.rng)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
@@ -506,6 +511,8 @@ mod tests {
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
use tokio::io::duplex;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use crate::crypto::SecureRandom;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_abridged() {
|
||||
@@ -541,8 +548,8 @@ mod tests {
|
||||
async fn test_framed_secure() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, SecureCodec::new());
|
||||
let mut reader = FramedRead::new(server, SecureCodec::new());
|
||||
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let frame = Frame::new(original.clone());
|
||||
@@ -557,8 +564,8 @@ mod tests {
|
||||
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
|
||||
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
|
||||
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||
|
||||
// Use 4-byte aligned data for abridged compatibility
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
@@ -607,7 +614,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_frame_too_large() {
|
||||
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
|
||||
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
|
||||
.with_max_frame_size(100);
|
||||
|
||||
// Create a "frame" that claims to be very large
|
||||
|
||||
@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::crypto::crc32;
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
use crate::crypto::{crc32, SecureRandom};
|
||||
use std::sync::Arc;
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
||||
/// Writer for secure intermediate MTProto framing
|
||||
pub struct SecureIntermediateFrameWriter<W> {
|
||||
upstream: W,
|
||||
rng: Arc<SecureRandom>,
|
||||
}
|
||||
|
||||
impl<W> SecureIntermediateFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
|
||||
Self { upstream, rng }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
}
|
||||
|
||||
// Add random padding (0-3 bytes)
|
||||
let padding_len = SECURE_RANDOM.range(4);
|
||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
||||
let padding_len = self.rng.range(4);
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
|
||||
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
use std::sync::Arc;
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_roundtrip() {
|
||||
@@ -539,7 +542,7 @@ mod tests {
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = SecureIntermediateFrameWriter::new(client);
|
||||
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
@@ -572,7 +575,7 @@ mod tests {
|
||||
async fn test_frame_reader_kind() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
|
||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
|
||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
|
||||
Reference in New Issue
Block a user