diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 68cd3dc..520b6ea 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -8,6 +8,8 @@ use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::{ProxyError, Result}; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; +use num_bigint::BigUint; +use num_traits::One; // ============= Public Constants ============= @@ -311,13 +313,27 @@ pub fn validate_tls_handshake( None } +fn curve25519_prime() -> BigUint { + (BigUint::one() << 255) - BigUint::from(19u32) +} + /// Generate a fake X25519 public key for TLS /// -/// This generates random bytes that look like a valid X25519 public key. -/// Since we're not doing real TLS, the actual cryptographic properties don't matter. +/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p, +/// which matches Python/C behavior and avoids DPI fingerprinting. pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { - let bytes = rng.bytes(32); - bytes.try_into().unwrap() + let mut n_bytes = [0u8; 32]; + n_bytes.copy_from_slice(&rng.bytes(32)); + + let n = BigUint::from_bytes_le(&n_bytes); + let p = curve25519_prime(); + let pk = (&n * &n) % &p; + + let mut out = pk.to_bytes_le(); + out.resize(32, 0); + let mut result = [0u8; 32]; + result.copy_from_slice(&out[..32]); + result } /// Build TLS ServerHello response @@ -498,6 +514,17 @@ mod tests { assert_eq!(key2.len(), 32); assert_ne!(key1, key2); // Should be random } + + #[test] + fn test_fake_x25519_key_is_quadratic_residue() { + let rng = SecureRandom::new(); + let key = gen_fake_x25519_key(&rng); + let p = curve25519_prime(); + let k_num = BigUint::from_bytes_le(&key); + let exponent = (&p - BigUint::one()) >> 1; + let legendre = k_num.modpow(&exponent, &p); + assert_eq!(legendre, BigUint::one()); + } #[test] fn test_tls_extension_builder() { @@ -641,4 +668,4 @@ mod tests { // Should return None (no match) but not panic assert!(result.is_none()); } -} \ No newline at end of file +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 7271e8c..041e7cb 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -339,6 +339,7 @@ impl RunningClientHandler { config, buffer_pool, local_addr, + rng, ) .await; } diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 46c004c..3cce39e 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -139,6 +139,8 @@ async fn do_tg_handshake_static( success.dc_idx, &success.dec_key, success.dec_iv, + &success.enc_key, + success.enc_iv, rng, config.general.fast_mode, ); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index ab8e70c..0023b7a 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -70,7 +70,7 @@ where let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_tls_digest(digest_half) { + if replay_checker.check_and_add_tls_digest(digest_half) { warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } @@ -122,8 +122,6 @@ where return HandshakeResult::Error(ProxyError::Io(e)); } - replay_checker.add_tls_digest(digest_half); - info!( peer = %peer, user = %validation.user, @@ -155,7 +153,7 @@ where let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - if replay_checker.check_handshake(dec_prekey_iv) { + if replay_checker.check_and_add_handshake(dec_prekey_iv) { warn!(peer = %peer, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } @@ -216,8 +214,6 @@ where let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - replay_checker.add_handshake(dec_prekey_iv); - let encryptor = AesCtr::new(&enc_key, enc_iv); let success = HandshakeSuccess { @@ -256,8 +252,10 @@ where pub fn generate_tg_nonce( proto_tag: ProtoTag, dc_idx: i16, - client_dec_key: &[u8; 32], - client_dec_iv: u128, + _client_dec_key: &[u8; 32], + _client_dec_iv: u128, + client_enc_key: &[u8; 32], + client_enc_iv: u128, rng: &SecureRandom, fast_mode: bool, ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { @@ -278,9 +276,11 @@ pub fn generate_tg_nonce( nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); if fast_mode { - nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); - nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] - .copy_from_slice(&client_dec_iv.to_be_bytes()); + let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN); + key_iv.extend_from_slice(client_enc_key); + key_iv.extend_from_slice(&client_enc_iv.to_be_bytes()); + key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce + nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&key_iv); } let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; @@ -332,10 +332,21 @@ mod tests { fn test_generate_tg_nonce() { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; let rng = SecureRandom::new(); let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); + generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); assert_eq!(nonce.len(), HANDSHAKE_LEN); @@ -347,10 +358,21 @@ mod tests { fn test_encrypt_tg_nonce() { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; let rng = SecureRandom::new(); let (nonce, _, _, _, _) = - generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); + generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); let encrypted = encrypt_tg_nonce(&nonce); @@ -379,4 +401,4 @@ mod tests { drop(success); // Drop impl zeroizes key material without panic } -} \ No newline at end of file +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0882d0e..279ae05 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -5,6 +5,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info, trace}; use crate::config::ProxyConfig; +use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::proxy::handshake::HandshakeSuccess; @@ -21,6 +22,7 @@ pub(crate) async fn handle_via_middle_proxy( _config: Arc, _buffer_pool: Arc, local_addr: SocketAddr, + rng: Arc, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -58,16 +60,17 @@ where tokio::select! { client_frame = read_client_payload(&mut crypto_reader, proto_tag) => { match client_frame { - Ok(Some(payload)) => { + Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); stats.add_user_octets_from(&user, payload.len() as u64); + let flags = if quickack { proto_flags | RPC_FLAG_QUICKACK } else { proto_flags }; me_pool.send_proxy_req( conn_id, success.dc_idx, peer, translated_local_addr, &payload, - proto_flags, + flags, ).await?; } Ok(None) => { @@ -83,7 +86,7 @@ where Some(MeResponse::Data { flags, data }) => { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); stats.add_user_octets_to(&user, data.len() as u64); - write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?; + write_client_payload(&mut crypto_writer, proto_tag, flags, &data, rng.as_ref()).await?; } Some(MeResponse::Ack(confirm)) => { trace!(conn_id, confirm, "ME->C quickack"); @@ -111,11 +114,11 @@ where async fn read_client_payload( client_reader: &mut CryptoReader, proto_tag: ProtoTag, -) -> Result>> +) -> Result, bool)>> where R: AsyncRead + Unpin + Send + 'static, { - let len = match proto_tag { + let (len, quickack) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; match client_reader.read_exact(&mut first).await { @@ -124,6 +127,7 @@ where Err(e) => return Err(ProxyError::Io(e)), } + let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; client_reader @@ -135,9 +139,10 @@ where (first[0] & 0x7f) as usize }; - len_words + let len = len_words .checked_mul(4) - .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))? + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; + (len, quickack) } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; @@ -146,7 +151,8 @@ where Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(ProxyError::Io(e)), } - (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize + let quickack = (len_buf[3] & 0x80) != 0; + ((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack) } }; @@ -159,7 +165,15 @@ where .read_exact(&mut payload) .await .map_err(ProxyError::Io)?; - Ok(Some(payload)) + + // Secure Intermediate: remove random padding (last len%4 bytes) + if proto_tag == ProtoTag::Secure { + let rem = len % 4; + if rem != 0 && payload.len() >= rem { + payload.truncate(len - rem); + } + } + Ok(Some((payload, quickack))) } async fn write_client_payload( @@ -167,6 +181,7 @@ async fn write_client_payload( proto_tag: ProtoTag, flags: u32, data: &[u8], + rng: &SecureRandom, ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, @@ -215,7 +230,12 @@ where .map_err(ProxyError::Io)?; } ProtoTag::Intermediate | ProtoTag::Secure => { - let mut len = data.len() as u32; + let padding_len = if proto_tag == ProtoTag::Secure { + (rng.bytes(1)[0] % 4) as usize + } else { + 0 + }; + let mut len = (data.len() + padding_len) as u32; if quickack { len |= 0x8000_0000; } @@ -227,6 +247,13 @@ where .write_all(data) .await .map_err(ProxyError::Io)?; + if padding_len > 0 { + let pad = rng.bytes(padding_len); + client_writer + .write_all(&pad) + .await + .map_err(ProxyError::Io)?; + } } } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index fb30742..5c3a084 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -212,28 +212,41 @@ impl ReplayChecker { (hasher.finish() as usize) & self.shard_mask } - fn check(&self, data: &[u8]) -> bool { + fn check_and_add_internal(&self, data: &[u8]) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); let mut shard = self.shards[idx].lock(); - let found = shard.check(data, Instant::now(), self.window); + let now = Instant::now(); + let found = shard.check(data, now, self.window); if found { self.hits.fetch_add(1, Ordering::Relaxed); + } else { + shard.add(data, now, self.window); + self.additions.fetch_add(1, Ordering::Relaxed); } found } - fn add(&self, data: &[u8]) { + fn add_only(&self, data: &[u8]) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); let mut shard = self.shards[idx].lock(); shard.add(data, Instant::now(), self.window); } - pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) } - pub fn add_handshake(&self, data: &[u8]) { self.add(data) } - pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) } - pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) } + pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { + self.check_and_add_internal(data) + } + + pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { + self.check_and_add_internal(data) + } + + // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. + pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } + pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) } + pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) } + pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) } pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; @@ -326,10 +339,9 @@ mod tests { #[test] fn test_replay_checker_basic() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); - assert!(!checker.check_handshake(b"test1")); - checker.add_handshake(b"test1"); - assert!(checker.check_handshake(b"test1")); - assert!(!checker.check_handshake(b"test2")); + assert!(!checker.check_handshake(b"test1")); // first time, inserts + assert!(checker.check_handshake(b"test1")); // duplicate + assert!(!checker.check_handshake(b"test2")); // new key inserts } #[test] @@ -343,7 +355,7 @@ mod tests { #[test] fn test_replay_checker_expiration() { let checker = ReplayChecker::new(100, Duration::from_millis(50)); - checker.add_handshake(b"expire"); + assert!(!checker.check_handshake(b"expire")); assert!(checker.check_handshake(b"expire")); std::thread::sleep(Duration::from_millis(100)); assert!(!checker.check_handshake(b"expire")); @@ -352,25 +364,25 @@ mod tests { #[test] fn test_replay_checker_stats() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); - checker.add_handshake(b"k1"); - checker.add_handshake(b"k2"); - checker.check_handshake(b"k1"); - checker.check_handshake(b"k3"); + assert!(!checker.check_handshake(b"k1")); + assert!(!checker.check_handshake(b"k2")); + assert!(checker.check_handshake(b"k1")); + assert!(!checker.check_handshake(b"k3")); let stats = checker.stats(); - assert_eq!(stats.total_additions, 2); - assert_eq!(stats.total_checks, 2); + assert_eq!(stats.total_additions, 3); + assert_eq!(stats.total_checks, 4); assert_eq!(stats.total_hits, 1); } #[test] fn test_replay_checker_many_keys() { - let checker = ReplayChecker::new(1000, Duration::from_secs(60)); + let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); for i in 0..500u32 { - checker.add(&i.to_le_bytes()); + checker.add_only(&i.to_le_bytes()); } for i in 0..500u32 { - assert!(checker.check(&i.to_le_bytes())); + assert!(checker.check_handshake(&i.to_le_bytes())); } assert_eq!(checker.stats().total_entries, 500); } -} \ No newline at end of file +} diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs index ac4a5f9..0de5532 100644 --- a/src/stream/buffer_pool.rs +++ b/src/stream/buffer_pool.rs @@ -381,9 +381,14 @@ mod tests { // Add a buffer to pool pool.preallocate(1); - // Now try_get should succeed - assert!(pool.try_get().is_some()); + // Now try_get should succeed once while the buffer is held + let buf = pool.try_get(); + assert!(buf.is_some()); + // While buffer is held, pool is empty assert!(pool.try_get().is_none()); + // Drop buffer -> returns to pool, should be obtainable again + drop(buf); + assert!(pool.try_get().is_some()); } #[test] @@ -448,4 +453,4 @@ mod tests { // All buffers should be returned assert!(stats.pooled > 0); } -} \ No newline at end of file +} diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index a4edf58..6a3c1d6 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -32,7 +32,7 @@ //! and uploads from iOS will break (media/file sending), while small traffic //! may still work. -use bytes::{Bytes, BytesMut, BufMut}; +use bytes::{Bytes, BytesMut}; use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -51,9 +51,10 @@ use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; /// TLS record header size (type + version + length) const TLS_HEADER_SIZE: usize = 5; -/// Maximum TLS fragment size per spec (plaintext fragment). -/// We use this for *outgoing* chunking, because we build plain ApplicationData records. -const MAX_TLS_PAYLOAD: usize = 16384; +/// Maximum TLS fragment size we emit for Application Data. +/// Real TLS 1.3 ciphertexts often add ~16-24 bytes AEAD overhead, so to mimic +/// on-the-wire record sizes we allow up to 16384 + 24 bytes of plaintext. +const MAX_TLS_PAYLOAD: usize = 16384 + 24; /// Maximum pending write buffer for one record remainder. /// Note: we never queue unlimited amount of data here; state holds at most one record. @@ -918,10 +919,8 @@ mod tests { let reader = ChunkedReader::new(&record, 100); let mut tls_reader = FakeTlsReader::new(reader); - let mut buf = vec![0u8; payload.len()]; - tls_reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, payload); + let buf = tls_reader.read_exact(payload.len()).await.unwrap(); + assert_eq!(&buf[..], payload); } #[tokio::test] @@ -935,13 +934,11 @@ mod tests { let reader = ChunkedReader::new(&data, 100); let mut tls_reader = FakeTlsReader::new(reader); - let mut buf1 = vec![0u8; payload1.len()]; - tls_reader.read_exact(&mut buf1).await.unwrap(); - assert_eq!(&buf1, payload1); + let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap(); + assert_eq!(&buf1[..], payload1); - let mut buf2 = vec![0u8; payload2.len()]; - tls_reader.read_exact(&mut buf2).await.unwrap(); - assert_eq!(&buf2, payload2); + let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap(); + assert_eq!(&buf2[..], payload2); } #[tokio::test] @@ -953,10 +950,9 @@ mod tests { let reader = ChunkedReader::new(&record, 1); // 1 byte at a time! let mut tls_reader = FakeTlsReader::new(reader); - let mut buf = vec![0u8; payload.len()]; - tls_reader.read_exact(&mut buf).await.unwrap(); + let buf = tls_reader.read_exact(payload.len()).await.unwrap(); - assert_eq!(&buf, payload); + assert_eq!(&buf[..], payload); } #[tokio::test] @@ -967,10 +963,9 @@ mod tests { let reader = ChunkedReader::new(&record, 7); // Awkward chunk size let mut tls_reader = FakeTlsReader::new(reader); - let mut buf = vec![0u8; payload.len()]; - tls_reader.read_exact(&mut buf).await.unwrap(); + let buf = tls_reader.read_exact(payload.len()).await.unwrap(); - assert_eq!(&buf, payload); + assert_eq!(&buf[..], payload); } #[tokio::test] @@ -983,10 +978,9 @@ mod tests { let reader = ChunkedReader::new(&data, 100); let mut tls_reader = FakeTlsReader::new(reader); - let mut buf = vec![0u8; payload.len()]; - tls_reader.read_exact(&mut buf).await.unwrap(); + let buf = tls_reader.read_exact(payload.len()).await.unwrap(); - assert_eq!(&buf, payload); + assert_eq!(&buf[..], payload); } #[tokio::test] @@ -1000,10 +994,9 @@ mod tests { let reader = ChunkedReader::new(&data, 3); // Small chunks let mut tls_reader = FakeTlsReader::new(reader); - let mut buf = vec![0u8; payload.len()]; - tls_reader.read_exact(&mut buf).await.unwrap(); + let buf = tls_reader.read_exact(payload.len()).await.unwrap(); - assert_eq!(&buf, payload); + assert_eq!(&buf[..], payload); } #[tokio::test] @@ -1244,4 +1237,4 @@ mod tests { let bytes = header.to_bytes(); assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]); } -} \ No newline at end of file +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 650a029..f1ef596 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -234,7 +234,10 @@ impl MePool { let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { - (IpMaterial::V4(srv), IpMaterial::V4(clt)) => { + // IPv4: reverse byte order for KDF (Python/C reference behavior) + (IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => { + srv.reverse(); + clt.reverse(); (Some(srv), Some(clt), None, None, clt, srv) } (IpMaterial::V6(srv), IpMaterial::V6(clt)) => {