Middle-End Drafts

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-15 12:30:40 +03:00
parent 427c7dd375
commit f2455c9cb1
9 changed files with 174 additions and 82 deletions

View File

@@ -8,6 +8,8 @@ use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use super::constants::*; use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use num_bigint::BigUint;
use num_traits::One;
// ============= Public Constants ============= // ============= Public Constants =============
@@ -311,13 +313,27 @@ pub fn validate_tls_handshake(
None None
} }
fn curve25519_prime() -> BigUint {
(BigUint::one() << 255) - BigUint::from(19u32)
}
/// Generate a fake X25519 public key for TLS /// Generate a fake X25519 public key for TLS
/// ///
/// This generates random bytes that look like a valid X25519 public key. /// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
/// Since we're not doing real TLS, the actual cryptographic properties don't matter. /// which matches Python/C behavior and avoids DPI fingerprinting.
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let bytes = rng.bytes(32); let mut n_bytes = [0u8; 32];
bytes.try_into().unwrap() n_bytes.copy_from_slice(&rng.bytes(32));
let n = BigUint::from_bytes_le(&n_bytes);
let p = curve25519_prime();
let pk = (&n * &n) % &p;
let mut out = pk.to_bytes_le();
out.resize(32, 0);
let mut result = [0u8; 32];
result.copy_from_slice(&out[..32]);
result
} }
/// Build TLS ServerHello response /// Build TLS ServerHello response
@@ -498,6 +514,17 @@ mod tests {
assert_eq!(key2.len(), 32); assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random assert_ne!(key1, key2); // Should be random
} }
#[test]
fn test_fake_x25519_key_is_quadratic_residue() {
let rng = SecureRandom::new();
let key = gen_fake_x25519_key(&rng);
let p = curve25519_prime();
let k_num = BigUint::from_bytes_le(&key);
let exponent = (&p - BigUint::one()) >> 1;
let legendre = k_num.modpow(&exponent, &p);
assert_eq!(legendre, BigUint::one());
}
#[test] #[test]
fn test_tls_extension_builder() { fn test_tls_extension_builder() {
@@ -641,4 +668,4 @@ mod tests {
// Should return None (no match) but not panic // Should return None (no match) but not panic
assert!(result.is_none()); assert!(result.is_none());
} }
} }

View File

@@ -339,6 +339,7 @@ impl RunningClientHandler {
config, config,
buffer_pool, buffer_pool,
local_addr, local_addr,
rng,
) )
.await; .await;
} }

View File

@@ -139,6 +139,8 @@ async fn do_tg_handshake_static(
success.dc_idx, success.dc_idx,
&success.dec_key, &success.dec_key,
success.dec_iv, success.dec_iv,
&success.enc_key,
success.enc_iv,
rng, rng,
config.general.fast_mode, config.general.fast_mode,
); );

View File

@@ -70,7 +70,7 @@ where
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_and_add_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
@@ -122,8 +122,6 @@ where
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
replay_checker.add_tls_digest(digest_half);
info!( info!(
peer = %peer, peer = %peer,
user = %validation.user, user = %validation.user,
@@ -155,7 +153,7 @@ where
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_and_add_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
@@ -216,8 +214,6 @@ where
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
replay_checker.add_handshake(dec_prekey_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
let success = HandshakeSuccess { let success = HandshakeSuccess {
@@ -256,8 +252,10 @@ where
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
dc_idx: i16, dc_idx: i16,
client_dec_key: &[u8; 32], _client_dec_key: &[u8; 32],
client_dec_iv: u128, _client_dec_iv: u128,
client_enc_key: &[u8; 32],
client_enc_iv: u128,
rng: &SecureRandom, rng: &SecureRandom,
fast_mode: bool, fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
@@ -278,9 +276,11 @@ pub fn generate_tg_nonce(
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode { if fast_mode {
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] key_iv.extend_from_slice(client_enc_key);
.copy_from_slice(&client_dec_iv.to_be_bytes()); key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&key_iv);
} }
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
@@ -332,10 +332,21 @@ mod tests {
fn test_generate_tg_nonce() { fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
@@ -347,10 +358,21 @@ mod tests {
fn test_encrypt_tg_nonce() { fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let (nonce, _, _, _, _) = let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let encrypted = encrypt_tg_nonce(&nonce); let encrypted = encrypt_tg_nonce(&nonce);
@@ -379,4 +401,4 @@ mod tests {
drop(success); drop(success);
// Drop impl zeroizes key material without panic // Drop impl zeroizes key material without panic
} }
} }

View File

@@ -5,6 +5,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
@@ -21,6 +22,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
_config: Arc<ProxyConfig>, _config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>, _buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr, local_addr: SocketAddr,
rng: Arc<SecureRandom>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -58,16 +60,17 @@ where
tokio::select! { tokio::select! {
client_frame = read_client_payload(&mut crypto_reader, proto_tag) => { client_frame = read_client_payload(&mut crypto_reader, proto_tag) => {
match client_frame { match client_frame {
Ok(Some(payload)) => { Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame"); trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64); stats.add_user_octets_from(&user, payload.len() as u64);
let flags = if quickack { proto_flags | RPC_FLAG_QUICKACK } else { proto_flags };
me_pool.send_proxy_req( me_pool.send_proxy_req(
conn_id, conn_id,
success.dc_idx, success.dc_idx,
peer, peer,
translated_local_addr, translated_local_addr,
&payload, &payload,
proto_flags, flags,
).await?; ).await?;
} }
Ok(None) => { Ok(None) => {
@@ -83,7 +86,7 @@ where
Some(MeResponse::Data { flags, data }) => { Some(MeResponse::Data { flags, data }) => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data"); trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats.add_user_octets_to(&user, data.len() as u64); stats.add_user_octets_to(&user, data.len() as u64);
write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?; write_client_payload(&mut crypto_writer, proto_tag, flags, &data, rng.as_ref()).await?;
} }
Some(MeResponse::Ack(confirm)) => { Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack"); trace!(conn_id, confirm, "ME->C quickack");
@@ -111,11 +114,11 @@ where
async fn read_client_payload<R>( async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>, client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
) -> Result<Option<Vec<u8>>> ) -> Result<Option<(Vec<u8>, bool)>>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
{ {
let len = match proto_tag { let (len, quickack) = match proto_tag {
ProtoTag::Abridged => { ProtoTag::Abridged => {
let mut first = [0u8; 1]; let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await { match client_reader.read_exact(&mut first).await {
@@ -124,6 +127,7 @@ where
Err(e) => return Err(ProxyError::Io(e)), Err(e) => return Err(ProxyError::Io(e)),
} }
let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f { let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3]; let mut ext = [0u8; 3];
client_reader client_reader
@@ -135,9 +139,10 @@ where
(first[0] & 0x7f) as usize (first[0] & 0x7f) as usize
}; };
len_words let len = len_words
.checked_mul(4) .checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))? .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack)
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4]; let mut len_buf = [0u8; 4];
@@ -146,7 +151,8 @@ where
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)), Err(e) => return Err(ProxyError::Io(e)),
} }
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize let quickack = (len_buf[3] & 0x80) != 0;
((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack)
} }
}; };
@@ -159,7 +165,15 @@ where
.read_exact(&mut payload) .read_exact(&mut payload)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
Ok(Some(payload))
// Secure Intermediate: remove random padding (last len%4 bytes)
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
}
}
Ok(Some((payload, quickack)))
} }
async fn write_client_payload<W>( async fn write_client_payload<W>(
@@ -167,6 +181,7 @@ async fn write_client_payload<W>(
proto_tag: ProtoTag, proto_tag: ProtoTag,
flags: u32, flags: u32,
data: &[u8], data: &[u8],
rng: &SecureRandom,
) -> Result<()> ) -> Result<()>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@@ -215,7 +230,12 @@ where
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len = data.len() as u32; let padding_len = if proto_tag == ProtoTag::Secure {
(rng.bytes(1)[0] % 4) as usize
} else {
0
};
let mut len = (data.len() + padding_len) as u32;
if quickack { if quickack {
len |= 0x8000_0000; len |= 0x8000_0000;
} }
@@ -227,6 +247,13 @@ where
.write_all(data) .write_all(data)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
if padding_len > 0 {
let pad = rng.bytes(padding_len);
client_writer
.write_all(&pad)
.await
.map_err(ProxyError::Io)?;
}
} }
} }

View File

@@ -212,28 +212,41 @@ impl ReplayChecker {
(hasher.finish() as usize) & self.shard_mask (hasher.finish() as usize) & self.shard_mask
} }
fn check(&self, data: &[u8]) -> bool { fn check_and_add_internal(&self, data: &[u8]) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed); self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data); let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock(); let mut shard = self.shards[idx].lock();
let found = shard.check(data, Instant::now(), self.window); let now = Instant::now();
let found = shard.check(data, now, self.window);
if found { if found {
self.hits.fetch_add(1, Ordering::Relaxed); self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add(data, now, self.window);
self.additions.fetch_add(1, Ordering::Relaxed);
} }
found found
} }
fn add(&self, data: &[u8]) { fn add_only(&self, data: &[u8]) {
self.additions.fetch_add(1, Ordering::Relaxed); self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data); let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock(); let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window); shard.add(data, Instant::now(), self.window);
} }
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) } pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
pub fn add_handshake(&self, data: &[u8]) { self.add(data) } self.check_and_add_internal(data)
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) } }
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
}
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) }
pub fn stats(&self) -> ReplayStats { pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0; let mut total_entries = 0;
@@ -326,10 +339,9 @@ mod tests {
#[test] #[test]
fn test_replay_checker_basic() { fn test_replay_checker_basic() {
let checker = ReplayChecker::new(100, Duration::from_secs(60)); let checker = ReplayChecker::new(100, Duration::from_secs(60));
assert!(!checker.check_handshake(b"test1")); assert!(!checker.check_handshake(b"test1")); // first time, inserts
checker.add_handshake(b"test1"); assert!(checker.check_handshake(b"test1")); // duplicate
assert!(checker.check_handshake(b"test1")); assert!(!checker.check_handshake(b"test2")); // new key inserts
assert!(!checker.check_handshake(b"test2"));
} }
#[test] #[test]
@@ -343,7 +355,7 @@ mod tests {
#[test] #[test]
fn test_replay_checker_expiration() { fn test_replay_checker_expiration() {
let checker = ReplayChecker::new(100, Duration::from_millis(50)); let checker = ReplayChecker::new(100, Duration::from_millis(50));
checker.add_handshake(b"expire"); assert!(!checker.check_handshake(b"expire"));
assert!(checker.check_handshake(b"expire")); assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100)); std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire")); assert!(!checker.check_handshake(b"expire"));
@@ -352,25 +364,25 @@ mod tests {
#[test] #[test]
fn test_replay_checker_stats() { fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60)); let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"k1"); assert!(!checker.check_handshake(b"k1"));
checker.add_handshake(b"k2"); assert!(!checker.check_handshake(b"k2"));
checker.check_handshake(b"k1"); assert!(checker.check_handshake(b"k1"));
checker.check_handshake(b"k3"); assert!(!checker.check_handshake(b"k3"));
let stats = checker.stats(); let stats = checker.stats();
assert_eq!(stats.total_additions, 2); assert_eq!(stats.total_additions, 3);
assert_eq!(stats.total_checks, 2); assert_eq!(stats.total_checks, 4);
assert_eq!(stats.total_hits, 1); assert_eq!(stats.total_hits, 1);
} }
#[test] #[test]
fn test_replay_checker_many_keys() { fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(1000, Duration::from_secs(60)); let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
for i in 0..500u32 { for i in 0..500u32 {
checker.add(&i.to_le_bytes()); checker.add_only(&i.to_le_bytes());
} }
for i in 0..500u32 { for i in 0..500u32 {
assert!(checker.check(&i.to_le_bytes())); assert!(checker.check_handshake(&i.to_le_bytes()));
} }
assert_eq!(checker.stats().total_entries, 500); assert_eq!(checker.stats().total_entries, 500);
} }
} }

View File

@@ -381,9 +381,14 @@ mod tests {
// Add a buffer to pool // Add a buffer to pool
pool.preallocate(1); pool.preallocate(1);
// Now try_get should succeed // Now try_get should succeed once while the buffer is held
assert!(pool.try_get().is_some()); let buf = pool.try_get();
assert!(buf.is_some());
// While buffer is held, pool is empty
assert!(pool.try_get().is_none()); assert!(pool.try_get().is_none());
// Drop buffer -> returns to pool, should be obtainable again
drop(buf);
assert!(pool.try_get().is_some());
} }
#[test] #[test]
@@ -448,4 +453,4 @@ mod tests {
// All buffers should be returned // All buffers should be returned
assert!(stats.pooled > 0); assert!(stats.pooled > 0);
} }
} }

View File

@@ -32,7 +32,7 @@
//! and uploads from iOS will break (media/file sending), while small traffic //! and uploads from iOS will break (media/file sending), while small traffic
//! may still work. //! may still work.
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{Bytes, BytesMut};
use std::io::{self, Error, ErrorKind, Result}; use std::io::{self, Error, ErrorKind, Result};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@@ -51,9 +51,10 @@ use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
/// TLS record header size (type + version + length) /// TLS record header size (type + version + length)
const TLS_HEADER_SIZE: usize = 5; const TLS_HEADER_SIZE: usize = 5;
/// Maximum TLS fragment size per spec (plaintext fragment). /// Maximum TLS fragment size we emit for Application Data.
/// We use this for *outgoing* chunking, because we build plain ApplicationData records. /// Real TLS 1.3 ciphertexts often add ~16-24 bytes AEAD overhead, so to mimic
const MAX_TLS_PAYLOAD: usize = 16384; /// on-the-wire record sizes we allow up to 16384 + 24 bytes of plaintext.
const MAX_TLS_PAYLOAD: usize = 16384 + 24;
/// Maximum pending write buffer for one record remainder. /// Maximum pending write buffer for one record remainder.
/// Note: we never queue unlimited amount of data here; state holds at most one record. /// Note: we never queue unlimited amount of data here; state holds at most one record.
@@ -918,10 +919,8 @@ mod tests {
let reader = ChunkedReader::new(&record, 100); let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()]; let buf = tls_reader.read_exact(payload.len()).await.unwrap();
tls_reader.read_exact(&mut buf).await.unwrap(); assert_eq!(&buf[..], payload);
assert_eq!(&buf, payload);
} }
#[tokio::test] #[tokio::test]
@@ -935,13 +934,11 @@ mod tests {
let reader = ChunkedReader::new(&data, 100); let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf1 = vec![0u8; payload1.len()]; let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
tls_reader.read_exact(&mut buf1).await.unwrap(); assert_eq!(&buf1[..], payload1);
assert_eq!(&buf1, payload1);
let mut buf2 = vec![0u8; payload2.len()]; let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
tls_reader.read_exact(&mut buf2).await.unwrap(); assert_eq!(&buf2[..], payload2);
assert_eq!(&buf2, payload2);
} }
#[tokio::test] #[tokio::test]
@@ -953,10 +950,9 @@ mod tests {
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time! let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()]; let buf = tls_reader.read_exact(payload.len()).await.unwrap();
tls_reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, payload); assert_eq!(&buf[..], payload);
} }
#[tokio::test] #[tokio::test]
@@ -967,10 +963,9 @@ mod tests {
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()]; let buf = tls_reader.read_exact(payload.len()).await.unwrap();
tls_reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, payload); assert_eq!(&buf[..], payload);
} }
#[tokio::test] #[tokio::test]
@@ -983,10 +978,9 @@ mod tests {
let reader = ChunkedReader::new(&data, 100); let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()]; let buf = tls_reader.read_exact(payload.len()).await.unwrap();
tls_reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, payload); assert_eq!(&buf[..], payload);
} }
#[tokio::test] #[tokio::test]
@@ -1000,10 +994,9 @@ mod tests {
let reader = ChunkedReader::new(&data, 3); // Small chunks let reader = ChunkedReader::new(&data, 3); // Small chunks
let mut tls_reader = FakeTlsReader::new(reader); let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()]; let buf = tls_reader.read_exact(payload.len()).await.unwrap();
tls_reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, payload); assert_eq!(&buf[..], payload);
} }
#[tokio::test] #[tokio::test]
@@ -1244,4 +1237,4 @@ mod tests {
let bytes = header.to_bytes(); let bytes = header.to_bytes();
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]); assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
} }
} }

View File

@@ -234,7 +234,10 @@ impl MePool {
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
match (server_ip, client_ip) { match (server_ip, client_ip) {
(IpMaterial::V4(srv), IpMaterial::V4(clt)) => { // IPv4: reverse byte order for KDF (Python/C reference behavior)
(IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => {
srv.reverse();
clt.reverse();
(Some(srv), Some(clt), None, None, clt, srv) (Some(srv), Some(clt), None, None, clt, srv)
} }
(IpMaterial::V6(srv), IpMaterial::V6(clt)) => { (IpMaterial::V6(srv), IpMaterial::V6(clt)) => {