Antireplay on sliding window + SecureRandom

This commit is contained in:
Alexey
2026-02-07 18:26:44 +03:00
parent 5876f0c4d5
commit b9428d9780
12 changed files with 171 additions and 76 deletions

View File

@@ -14,6 +14,7 @@ fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_replay_window_secs() -> u64 { 1800 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
@@ -187,6 +188,9 @@ pub struct AccessConfig {
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default = "default_replay_window_secs")]
pub replay_window_secs: u64,
#[serde(default)]
pub ignore_time_skew: bool,
}
@@ -201,6 +205,7 @@ impl Default for AccessConfig {
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
replay_window_secs: default_replay_window_secs(),
ignore_time_skew: false,
}
}

View File

@@ -6,4 +6,4 @@ pub mod random;
pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM};
pub use random::SecureRandom;

View File

@@ -4,11 +4,6 @@ use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom {
inner: Mutex<SecureRandomInner>,

View File

@@ -21,6 +21,7 @@ mod util;
use crate::config::ProxyConfig;
use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker};
use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
@@ -68,10 +69,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
// Initialize global ReplayChecker
// Using sharded implementation for better concurrency
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
let replay_checker = Arc::new(ReplayChecker::new(
config.access.replay_check_len,
Duration::from_secs(config.access.replay_window_secs),
));
// Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
@@ -166,6 +171,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move {
loop {
@@ -176,6 +182,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move {
if let Err(e) = ClientHandler::new(
@@ -185,7 +192,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
stats,
upstream_manager,
replay_checker,
buffer_pool
buffer_pool,
rng
).run().await {
// Log only relevant errors
}

View File

@@ -4,7 +4,7 @@
//! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result};
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
@@ -315,8 +315,8 @@ pub fn validate_tls_handshake(
///
/// This generates random bytes that look like a valid X25519 public key.
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
pub fn gen_fake_x25519_key() -> [u8; 32] {
let bytes = SECURE_RANDOM.bytes(32);
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let bytes = rng.bytes(32);
bytes.try_into().unwrap()
}
@@ -333,8 +333,9 @@ pub fn build_server_hello(
client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8],
fake_cert_len: usize,
rng: &SecureRandom,
) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key();
let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
@@ -351,7 +352,7 @@ pub fn build_server_hello(
];
// Build fake certificate (Application Data record)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
let fake_cert = rng.bytes(fake_cert_len);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
@@ -489,8 +490,9 @@ mod tests {
#[test]
fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key();
let key2 = gen_fake_x25519_key();
let rng = SecureRandom::new();
let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32);
@@ -545,7 +547,8 @@ mod tests {
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
// Should have at least 3 records
assert!(response.len() > 100);
@@ -577,8 +580,9 @@ mod tests {
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
let rng = SecureRandom::new();
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
// Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];

View File

@@ -15,7 +15,7 @@ use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr;
use crate::crypto::{AesCtr, SecureRandom};
// Use absolute paths to avoid confusion
use crate::proxy::handshake::{
@@ -37,6 +37,7 @@ pub struct RunningClientHandler {
replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
}
impl ClientHandler {
@@ -49,6 +50,7 @@ impl ClientHandler {
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> RunningClientHandler {
RunningClientHandler {
stream,
@@ -58,6 +60,7 @@ impl ClientHandler {
replay_checker,
upstream_manager,
buffer_pool,
rng,
}
}
}
@@ -168,6 +171,7 @@ impl RunningClientHandler {
peer,
&config,
&replay_checker,
&self.rng,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
@@ -211,7 +215,8 @@ impl RunningClientHandler {
self.upstream_manager,
self.stats,
self.config,
buffer_pool
buffer_pool,
self.rng
).await
}
@@ -272,7 +277,8 @@ impl RunningClientHandler {
self.upstream_manager,
self.stats,
self.config,
buffer_pool
buffer_pool,
self.rng
).await
}
@@ -285,6 +291,7 @@ impl RunningClientHandler {
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
@@ -321,6 +328,7 @@ impl RunningClientHandler {
tg_stream,
&success,
&config,
rng.as_ref(),
).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
@@ -401,12 +409,14 @@ impl RunningClientHandler {
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
&success.dec_key, // Client's dec key
success.dec_iv,
rng,
config.general.fast_mode,
);

View File

@@ -4,8 +4,7 @@ use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info};
use crate::crypto::{sha256, AesCtr};
use crate::crypto::random::SECURE_RANDOM;
use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
@@ -42,6 +41,7 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
rng: &SecureRandom,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
@@ -101,6 +101,7 @@ where
&validation.digest,
&validation.session_id,
config.censorship.fake_cert_len,
rng,
);
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -264,10 +265,11 @@ pub fn generate_tg_nonce(
proto_tag: ProtoTag,
client_dec_key: &[u8; 32],
client_dec_iv: u128,
rng: &SecureRandom,
fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
@@ -323,8 +325,9 @@ mod tests {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
// Check length
assert_eq!(nonce.len(), HANDSHAKE_LEN);
@@ -339,8 +342,9 @@ mod tests {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce);

View File

@@ -2,13 +2,14 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Instant, Duration};
use dashmap::DashMap;
use parking_lot::{RwLock, Mutex};
use lru::LruCache;
use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
/// Thread-safe statistics
#[derive(Default)]
@@ -143,57 +144,112 @@ impl Stats {
}
}
/// Sharded Replay attack checker using LRU cache
/// Sharded Replay attack checker using LRU cache + sliding window
/// Uses multiple independent LRU caches to reduce lock contention
pub struct ReplayChecker {
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
}
struct ReplayEntry {
seen_at: Instant,
}
struct ReplayShard {
cache: LruCache<Vec<u8>, ReplayEntry>,
queue: VecDeque<(Instant, Vec<u8>)>,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
}
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now - window;
while let Some((ts, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (ts_old, key_old) = self.queue.pop_front().unwrap();
if let Some(entry) = self.cache.get(&key_old) {
if entry.seen_at <= ts_old {
self.cache.pop(&key_old);
}
}
}
}
}
impl ReplayChecker {
/// Create new replay checker with specified capacity per shard
/// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize) -> Self {
pub fn new(total_capacity: usize, window: Duration) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap)));
shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
shard_mask: num_shards - 1,
window,
}
}
fn get_shard(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
let now = Instant::now();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.get(&key).is_some()
}
fn add(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
let now = Instant::now();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
shard.queue.push_back((now, key));
}
pub fn check_handshake(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
self.check(data)
}
pub fn add_handshake(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
self.add(data)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
self.check(data)
}
pub fn add_tls_digest(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
self.add(data)
}
}
@@ -217,7 +273,7 @@ mod tests {
#[test]
fn test_replay_checker_sharding() {
let checker = ReplayChecker::new(100);
let checker = ReplayChecker::new(100, Duration::from_secs(60));
let data1 = b"test1";
let data2 = b"test2";

View File

@@ -5,8 +5,10 @@
use bytes::{Bytes, BytesMut};
use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
// ============= Frame Types =============
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
// ============= Codec Factory =============
/// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
match proto_tag {
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
}
}

View File

@@ -5,9 +5,11 @@
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec =============
@@ -21,14 +23,17 @@ pub struct FrameCodec {
proto_tag: ProtoTag,
/// Maximum allowed frame size
max_frame_size: usize,
/// RNG for secure padding
rng: Arc<SecureRandom>,
}
impl FrameCodec {
/// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag) -> Self {
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
Self {
proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default
rng,
}
}
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
}
}
}
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
use crate::crypto::random::SECURE_RANDOM;
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
// Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned
(SECURE_RANDOM.range(3) + 1) as usize
(rng.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
SECURE_RANDOM.range(4) as usize
rng.range(4) as usize
};
let total_len = data.len() + padding_len;
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(data);
if padding_len > 0 {
let padding = SECURE_RANDOM.bytes(padding_len);
let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding);
}
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
/// Secure Intermediate protocol codec
pub struct SecureCodec {
max_frame_size: usize,
rng: Arc<SecureRandom>,
}
impl SecureCodec {
pub fn new() -> Self {
pub fn new(rng: Arc<SecureRandom>) -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
rng,
}
}
}
impl Default for SecureCodec {
fn default() -> Self {
Self::new()
Self::new(Arc::new(SecureRandom::new()))
}
}
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst)
encode_secure(&frame, dst, &self.rng)
}
}
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_secure(frame, dst)?;
encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before)
}
@@ -506,6 +511,8 @@ mod tests {
use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex;
use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use std::sync::Arc;
#[tokio::test]
async fn test_framed_abridged() {
@@ -541,8 +548,8 @@ mod tests {
async fn test_framed_secure() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new());
let mut reader = FramedRead::new(server, SecureCodec::new());
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
@@ -557,8 +564,8 @@ mod tests {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
// Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
@@ -607,7 +614,7 @@ mod tests {
#[test]
fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100);
// Create a "frame" that claims to be very large

View File

@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*;
use crate::crypto::crc32;
use crate::crypto::random::SECURE_RANDOM;
use crate::crypto::{crc32, SecureRandom};
use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame =============
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
/// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> {
upstream: W,
rng: Arc<SecureRandom>,
}
impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
Self { upstream, rng }
}
}
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
}
// Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4);
let padding = SECURE_RANDOM.bytes(padding_len);
let padding_len = self.rng.range(4);
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
}
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
}
}
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
mod tests {
use super::*;
use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
#[tokio::test]
async fn test_abridged_roundtrip() {
@@ -539,7 +542,7 @@ mod tests {
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client);
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
@@ -572,7 +575,7 @@ mod tests {
async fn test_frame_reader_kind() {
let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4];