From 7be179b3c0aa9b8060f1bf6a99d7388d0f13a245 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Fri, 2 Jan 2026 01:37:02 +0300 Subject: [PATCH] Added accurate MTProto Frame Types + Tokio Async Intergr --- src/protocol/tls.rs | 498 ++++++++++++++++-- src/stream/crypto_stream.rs | 973 ++++++++++++++++++++++++++++++------ src/stream/frame.rs | 187 +++++++ src/stream/frame_codec.rs | 621 +++++++++++++++++++++++ src/stream/mod.rs | 23 +- 5 files changed, 2103 insertions(+), 199 deletions(-) create mode 100644 src/stream/frame.rs create mode 100644 src/stream/frame_codec.rs diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index f050d9f..354ee9a 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -1,14 +1,22 @@ //! Fake TLS 1.3 Handshake +//! +//! This module handles the fake TLS 1.3 handshake used by MTProto proxy +//! for domain fronting. The handshake looks like valid TLS 1.3 but +//! actually carries MTProto authentication data. use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; use crate::error::{ProxyError, Result}; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; +// ============= Public Constants ============= + /// TLS handshake digest length pub const TLS_DIGEST_LEN: usize = 32; + /// Position of digest in TLS ClientHello pub const TLS_DIGEST_POS: usize = 11; + /// Length to store for replay protection (first 16 bytes of digest) pub const TLS_DIGEST_HALF_LEN: usize = 16; @@ -16,6 +24,26 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16; pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +// ============= Private Constants ============= + +/// TLS Extension types +mod extension_type { + pub const KEY_SHARE: u16 = 0x0033; + pub const SUPPORTED_VERSIONS: u16 = 0x002b; +} + +/// TLS Cipher Suites +mod cipher_suite { + pub const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; +} + +/// TLS Named Curves +mod named_curve { + 
pub const X25519: u16 = 0x001d; +} + +// ============= TLS Validation Result ============= + /// Result of validating TLS handshake #[derive(Debug)] pub struct TlsValidation { @@ -29,7 +57,185 @@ pub struct TlsValidation { pub timestamp: u32, } +// ============= TLS Extension Builder ============= + +/// Builder for TLS extensions with correct length calculation +struct TlsExtensionBuilder { + extensions: Vec, +} + +impl TlsExtensionBuilder { + fn new() -> Self { + Self { + extensions: Vec::with_capacity(128), + } + } + + /// Add Key Share extension with X25519 key + fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self { + // Extension type: key_share (0x0033) + self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); + + // Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes + // Extension data length + let entry_len: u16 = 2 + 2 + 32; // curve + length + key + self.extensions.extend_from_slice(&entry_len.to_be_bytes()); + + // Named curve: x25519 + self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes()); + + // Key length + self.extensions.extend_from_slice(&(32u16).to_be_bytes()); + + // Key data + self.extensions.extend_from_slice(public_key); + + self + } + + /// Add Supported Versions extension + fn add_supported_versions(&mut self, version: u16) -> &mut Self { + // Extension type: supported_versions (0x002b) + self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); + + // Extension data: length (2) + version (2) + self.extensions.extend_from_slice(&(2u16).to_be_bytes()); + + // Selected version + self.extensions.extend_from_slice(&version.to_be_bytes()); + + self + } + + /// Build final extensions with length prefix + fn build(self) -> Vec { + let mut result = Vec::with_capacity(2 + self.extensions.len()); + + // Extensions length (2 bytes) + let len = self.extensions.len() as u16; + result.extend_from_slice(&len.to_be_bytes()); + + // Extensions data + 
result.extend_from_slice(&self.extensions); + + result + } + + /// Get current extensions without length prefix (for calculation) + #[allow(dead_code)] + fn as_bytes(&self) -> &[u8] { + &self.extensions + } +} + +// ============= ServerHello Builder ============= + +/// Builder for TLS ServerHello with correct structure +struct ServerHelloBuilder { + /// Random bytes (32 bytes, will contain digest) + random: [u8; 32], + /// Session ID (echoed from ClientHello) + session_id: Vec, + /// Cipher suite + cipher_suite: [u8; 2], + /// Compression method + compression: u8, + /// Extensions + extensions: TlsExtensionBuilder, +} + +impl ServerHelloBuilder { + fn new(session_id: Vec) -> Self { + Self { + random: [0u8; 32], + session_id, + cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, + compression: 0x00, + extensions: TlsExtensionBuilder::new(), + } + } + + fn with_x25519_key(mut self, key: &[u8; 32]) -> Self { + self.extensions.add_key_share(key); + self + } + + fn with_tls13_version(mut self) -> Self { + // TLS 1.3 = 0x0304 + self.extensions.add_supported_versions(0x0304); + self + } + + /// Build ServerHello message (without record header) + fn build_message(&self) -> Vec { + let extensions = self.extensions.extensions.clone(); + let extensions_len = extensions.len() as u16; + + // Calculate total length + let body_len = 2 + // version + 32 + // random + 1 + self.session_id.len() + // session_id length + data + 2 + // cipher suite + 1 + // compression + 2 + extensions.len(); // extensions length + data + + let mut message = Vec::with_capacity(4 + body_len); + + // Handshake header + message.push(0x02); // ServerHello message type + + // 3-byte length + let len_bytes = (body_len as u32).to_be_bytes(); + message.extend_from_slice(&len_bytes[1..4]); + + // Server version (TLS 1.2 in header, actual version in extension) + message.extend_from_slice(&TLS_VERSION); + + // Random (32 bytes) - placeholder, will be replaced with digest + 
message.extend_from_slice(&self.random); + + // Session ID + message.push(self.session_id.len() as u8); + message.extend_from_slice(&self.session_id); + + // Cipher suite + message.extend_from_slice(&self.cipher_suite); + + // Compression method + message.push(self.compression); + + // Extensions length + message.extend_from_slice(&extensions_len.to_be_bytes()); + + // Extensions data + message.extend_from_slice(&extensions); + + message + } + + /// Build complete ServerHello TLS record + fn build_record(&self) -> Vec { + let message = self.build_message(); + + let mut record = Vec::with_capacity(5 + message.len()); + + // TLS record header + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(message.len() as u16).to_be_bytes()); + + // Message + record.extend_from_slice(&message); + + record + } +} + +// ============= Public Functions ============= + /// Validate TLS ClientHello against user secrets +/// +/// Returns validation result if a matching user is found. pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], @@ -86,7 +292,8 @@ pub fn validate_tls_handshake( // Check time skew if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) - let is_boot_time = timestamp < 60 * 60 * 24 * 1000; + // This is a quirk in some clients that use uptime instead of real time + let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) { continue; @@ -105,15 +312,22 @@ pub fn validate_tls_handshake( } /// Generate a fake X25519 public key for TLS -/// This generates a value that looks like a valid X25519 key +/// +/// This generates random bytes that look like a valid X25519 public key. +/// Since we're not doing real TLS, the actual cryptographic properties don't matter. 
pub fn gen_fake_x25519_key() -> [u8; 32] { - // For simplicity, just generate random 32 bytes - // In real X25519, this would be a point on the curve let bytes = SECURE_RANDOM.bytes(32); bytes.try_into().unwrap() } /// Build TLS ServerHello response +/// +/// This builds a complete TLS 1.3-like response including: +/// - ServerHello record with extensions +/// - Change Cipher Spec record +/// - Fake encrypted certificate (Application Data record) +/// +/// The response includes an HMAC digest that the client can verify. pub fn build_server_hello( secret: &[u8], client_digest: &[u8; TLS_DIGEST_LEN], @@ -122,62 +336,48 @@ pub fn build_server_hello( ) -> Vec { let x25519_key = gen_fake_x25519_key(); - // TLS extensions - let mut extensions = Vec::new(); - extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder - extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension - extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve - extensions.extend_from_slice(&x25519_key); - extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions + // Build ServerHello + let server_hello = ServerHelloBuilder::new(session_id.to_vec()) + .with_x25519_key(&x25519_key) + .with_tls13_version() + .build_record(); - // ServerHello body - let mut srv_hello = Vec::new(); - srv_hello.extend_from_slice(&TLS_VERSION); - srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest - srv_hello.push(session_id.len() as u8); - srv_hello.extend_from_slice(session_id); - srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 - srv_hello.push(0x00); // No compression - srv_hello.extend_from_slice(&extensions); + // Build Change Cipher Spec record + let change_cipher_spec = [ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], TLS_VERSION[1], + 0x00, 0x01, // length = 1 + 0x01, // CCS byte + ]; - // Build complete packet - let mut hello_pkt = Vec::new(); - - // ServerHello record - 
hello_pkt.push(TLS_RECORD_HANDSHAKE); - hello_pkt.extend_from_slice(&TLS_VERSION); - hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes()); - hello_pkt.push(0x02); // ServerHello message type - let len_bytes = (srv_hello.len() as u32).to_be_bytes(); - hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length - hello_pkt.extend_from_slice(&srv_hello); - - // Change Cipher Spec record - hello_pkt.extend_from_slice(&[ - TLS_RECORD_CHANGE_CIPHER, - TLS_VERSION[0], TLS_VERSION[1], - 0x00, 0x01, 0x01 - ]); - - // Application Data record (fake certificate) + // Build fake certificate (Application Data record) let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); - hello_pkt.push(TLS_RECORD_APPLICATION); - hello_pkt.extend_from_slice(&TLS_VERSION); - hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes()); - hello_pkt.extend_from_slice(&fake_cert); + let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); + app_data_record.push(TLS_RECORD_APPLICATION); + app_data_record.extend_from_slice(&TLS_VERSION); + app_data_record.extend_from_slice(&(fake_cert_len as u16).to_be_bytes()); + app_data_record.extend_from_slice(&fake_cert); + + // Combine all records + let mut response = Vec::with_capacity( + server_hello.len() + change_cipher_spec.len() + app_data_record.len() + ); + response.extend_from_slice(&server_hello); + response.extend_from_slice(&change_cipher_spec); + response.extend_from_slice(&app_data_record); // Compute HMAC for the response - let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len()); + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len()); hmac_input.extend_from_slice(client_digest); - hmac_input.extend_from_slice(&hello_pkt); + hmac_input.extend_from_slice(&response); let response_digest = sha256_hmac(secret, &hmac_input); - // Insert computed digest - // Position: after record header (5) + message type/length (4) + version (2) = 11 - hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS 
+ TLS_DIGEST_LEN] + // Insert computed digest into ServerHello + // Position: record header (5) + message type (1) + length (3) + version (2) = 11 + response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] .copy_from_slice(&response_digest); - hello_pkt + response } /// Check if bytes look like a TLS ClientHello @@ -186,7 +386,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { return false; } - // TLS record header: 0x16 0x03 0x01 + // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0) first_bytes[0] == TLS_RECORD_HANDSHAKE && first_bytes[1] == 0x03 && first_bytes[2] == 0x01 @@ -206,6 +406,61 @@ pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { Some((record_type, length)) } +/// Validate a ServerHello response structure +/// +/// This is useful for testing that our ServerHello is well-formed. +#[cfg(test)] +fn validate_server_hello_structure(data: &[u8]) -> Result<()> { + if data.len() < 5 { + return Err(ProxyError::InvalidTlsRecord { + record_type: 0, + version: [0, 0], + }); + } + + // Check record header + if data[0] != TLS_RECORD_HANDSHAKE { + return Err(ProxyError::InvalidTlsRecord { + record_type: data[0], + version: [data[1], data[2]], + }); + } + + // Check version + if data[1..3] != TLS_VERSION { + return Err(ProxyError::InvalidTlsRecord { + record_type: data[0], + version: [data[1], data[2]], + }); + } + + // Check record length + let record_len = u16::from_be_bytes([data[3], data[4]]) as usize; + if data.len() < 5 + record_len { + return Err(ProxyError::InvalidHandshake( + format!("ServerHello record truncated: expected {}, got {}", + 5 + record_len, data.len()) + )); + } + + // Check message type + if data[5] != 0x02 { + return Err(ProxyError::InvalidHandshake( + format!("Expected ServerHello (0x02), got 0x{:02x}", data[5]) + )); + } + + // Parse message length + let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize; + if msg_len + 4 != record_len { + return Err(ProxyError::InvalidHandshake( + 
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len) + )); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -241,4 +496,145 @@ mod tests { assert_eq!(key2.len(), 32); assert_ne!(key1, key2); // Should be random } + + #[test] + fn test_tls_extension_builder() { + let key = [0x42u8; 32]; + + let mut builder = TlsExtensionBuilder::new(); + builder.add_key_share(&key); + builder.add_supported_versions(0x0304); + + let result = builder.build(); + + // Check length prefix + let len = u16::from_be_bytes([result[0], result[1]]) as usize; + assert_eq!(len, result.len() - 2); + + // Check key_share extension is present + assert!(result.len() > 40); // At least key share + } + + #[test] + fn test_server_hello_builder() { + let session_id = vec![0x01, 0x02, 0x03, 0x04]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id.clone()) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + + // Validate structure + validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); + + // Check record type + assert_eq!(record[0], TLS_RECORD_HANDSHAKE); + + // Check version + assert_eq!(&record[1..3], &TLS_VERSION); + + // Check message type (ServerHello = 0x02) + assert_eq!(record[5], 0x02); + } + + #[test] + fn test_build_server_hello_structure() { + let secret = b"test secret"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let response = build_server_hello(secret, &client_digest, &session_id, 2048); + + // Should have at least 3 records + assert!(response.len() > 100); + + // First record should be ServerHello + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + + // Validate ServerHello structure + validate_server_hello_structure(&response).expect("Invalid ServerHello"); + + // Find Change Cipher Spec + let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = server_hello_len; + + 
assert!(response.len() > ccs_start + 6); + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + + // Find Application Data + let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let app_start = ccs_start + ccs_len; + + assert!(response.len() > app_start + 5); + assert_eq!(response[app_start], TLS_RECORD_APPLICATION); + } + + #[test] + fn test_build_server_hello_digest() { + let secret = b"test secret key here"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let response1 = build_server_hello(secret, &client_digest, &session_id, 1024); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024); + + // Digest position should have non-zero data + let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(!digest1.iter().all(|&b| b == 0)); + + // Different calls should have different digests (due to random cert) + let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert_ne!(digest1, digest2); + } + + #[test] + fn test_server_hello_extensions_length() { + let session_id = vec![0x01; 32]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + + // Parse to find extensions + let msg_start = 5; // After record header + let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + + // Skip to session ID + let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32) + let session_id_len = record[session_id_pos] as usize; + + // Skip to extensions + let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1) + let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; + + // Verify extensions length matches actual data + let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + 
msg_len]; + assert_eq!(ext_len, extensions_data.len(), + "Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len()); + } + + #[test] + fn test_validate_tls_handshake_format() { + // Build a minimal ClientHello-like structure + let mut handshake = vec![0u8; 100]; + + // Put a valid-looking digest at position 11 + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&[0x42; 32]); + + // Session ID length + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + // This won't validate (wrong HMAC) but shouldn't panic + let secrets = vec![("test".to_string(), b"secret".to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + + // Should return None (no match) but not panic + assert!(result.is_none()); + } } \ No newline at end of file diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index 123dfa5..5ee93a7 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -1,18 +1,98 @@ //! Encrypted stream wrappers using AES-CTR +//! +//! This module provides stateful async stream wrappers that handle +//! encryption/decryption with proper partial read/write handling. +//! +//! Key design principles: +//! - Explicit state machines for all async operations +//! - Never lose data on partial reads/writes +//! - Honest reporting of bytes written +//! 
- Bounded internal buffers with backpressure use bytes::{Bytes, BytesMut, BufMut}; -use std::io::{Error, ErrorKind, Result}; +use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; -use crate::crypto::AesCtr; -use parking_lot::Mutex; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -/// Reader that decrypts data using AES-CTR +use crate::crypto::AesCtr; +use crate::error::StreamError; +use super::state::{StreamState, ReadBuffer, WriteBuffer, YieldBuffer}; + +// ============= Constants ============= + +/// Maximum size for pending write buffer (256KB) +const MAX_PENDING_WRITE: usize = 256 * 1024; + +/// Default read buffer capacity +const DEFAULT_READ_CAPACITY: usize = 16 * 1024; + +// ============= CryptoReader State ============= + +/// State machine states for CryptoReader +#[derive(Debug)] +enum CryptoReaderState { + /// Ready to read new data + Idle, + + /// Have decrypted data ready to yield to caller + Yielding { + /// Buffer containing decrypted data + buffer: YieldBuffer, + }, + + /// Stream encountered an error and cannot be used + Poisoned { + /// The error that caused poisoning (taken on first access) + error: Option, + }, +} + +impl StreamState for CryptoReaderState { + fn is_terminal(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn is_poisoned(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn state_name(&self) -> &'static str { + match self { + Self::Idle => "Idle", + Self::Yielding { .. } => "Yielding", + Self::Poisoned { .. } => "Poisoned", + } + } +} + +// ============= CryptoReader ============= + +/// Reader that decrypts data using AES-CTR with proper state machine +/// +/// This reader handles partial reads correctly by maintaining internal state +/// and never losing any data that has been read from upstream. 
+/// +/// # State Machine +/// +/// ┌──────────┐ read ┌──────────┐ +/// │ Idle │ ------------> │ Yielding │ +/// │ │ <------------ │ │ +/// └──────────┘ drained └──────────┘ +/// │ │ +/// │ errors │ +/// ┌──────────────────────────────────────┐ +/// │ Poisoned │ +/// └──────────────────────────────────────┘ pub struct CryptoReader { + /// Upstream reader upstream: R, + /// AES-CTR decryptor decryptor: AesCtr, - buffer: BytesMut, + /// Current state + state: CryptoReaderState, + /// Internal read buffer for upstream reads + read_buf: BytesMut, } impl CryptoReader { @@ -21,7 +101,8 @@ impl CryptoReader { Self { upstream, decryptor, - buffer: BytesMut::with_capacity(8192), + state: CryptoReaderState::Idle, + read_buf: BytesMut::with_capacity(DEFAULT_READ_CAPACITY), } } @@ -39,6 +120,33 @@ impl CryptoReader { pub fn into_inner(self) -> R { self.upstream } + + /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { + self.state.is_poisoned() + } + + /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { + self.state.state_name() + } + + /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { + self.state = CryptoReaderState::Poisoned { error: Some(error) }; + } + + /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { + match &mut self.state { + CryptoReaderState::Poisoned { error } => { + error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }) + } + _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), + } + } } impl AsyncRead for CryptoReader { @@ -49,89 +157,350 @@ impl AsyncRead for CryptoReader { ) -> Poll> { let this = self.get_mut(); - if !this.buffer.is_empty() { - let to_copy = this.buffer.len().min(buf.remaining()); - buf.put_slice(&this.buffer.split_to(to_copy)); - return Poll::Ready(Ok(())); - } - - // Zero-copy Reader - let before = buf.filled().len(); - - match Pin::new(&mut 
this.upstream).poll_read(cx, buf) { - Poll::Ready(Ok(())) => { - let after = buf.filled().len(); - let bytes_read = after - before; - - if bytes_read > 0 { - // Decrypt in-place - let filled = buf.filled_mut(); - this.decryptor.apply(&mut filled[before..after]); + loop { + match &mut this.state { + // Poisoned state - return error + CryptoReaderState::Poisoned { .. } => { + let err = this.take_poison_error(); + return Poll::Ready(Err(err)); } - Poll::Ready(Ok(())) + // Have buffered data to yield + CryptoReaderState::Yielding { buffer } => { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); + let dst = buf.initialize_unfilled_to(to_copy); + let copied = buffer.copy_to(dst); + buf.advance(copied); + + // If buffer is drained, transition to Idle + if buffer.is_empty() { + this.state = CryptoReaderState::Idle; + } + + return Poll::Ready(Ok(())); + } + + // Ready to read from upstream + CryptoReaderState::Idle => { + // If caller's buffer is empty, nothing to do + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + // Try to read directly into caller's buffer for zero-copy path + // We need to be careful: read into unfilled portion, then decrypt + let before_len = buf.filled().len(); + + match Pin::new(&mut this.upstream).poll_read(cx, buf) { + Poll::Pending => return Poll::Pending, + + Poll::Ready(Err(e)) => { + this.poison(io::Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + + Poll::Ready(Ok(())) => { + let after_len = buf.filled().len(); + let bytes_read = after_len - before_len; + + if bytes_read == 0 { + // EOF + return Poll::Ready(Ok(())); + } + + // Decrypt the newly read data in-place + let filled = buf.filled_mut(); + this.decryptor.apply(&mut filled[before_len..after_len]); + + return Poll::Ready(Ok(())); + } + } + } } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, } } } impl 
CryptoReader { - /// Read and decrypt exactly n bytes with Async + /// Read and decrypt exactly n bytes + /// + /// This is a convenience method that accumulates data until + /// exactly n bytes are available. pub async fn read_exact_decrypt(&mut self, n: usize) -> Result { - let mut result = BytesMut::with_capacity(n); + use tokio::io::AsyncReadExt; - if !self.buffer.is_empty() { - let to_take = self.buffer.len().min(n); - result.extend_from_slice(&self.buffer.split_to(to_take)); + if self.is_poisoned() { + return Err(self.take_poison_error()); } - // Reread + let mut result = BytesMut::with_capacity(n); + + // First drain any buffered data from Yielding state + if let CryptoReaderState::Yielding { buffer } = &mut self.state { + let to_take = buffer.remaining().min(n); + let mut temp = vec![0u8; to_take]; + buffer.copy_to(&mut temp); + result.extend_from_slice(&temp); + + if buffer.is_empty() { + self.state = CryptoReaderState::Idle; + } + } + + // Read remaining from upstream while result.len() < n { let mut temp = vec![0u8; n - result.len()]; - let read = self.upstream.read(&mut temp).await?; + let read = self.read(&mut temp).await?; if read == 0 { - return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed")); + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + format!("expected {} bytes, got {}", n, result.len()) + )); } - // Decrypt - self.decryptor.apply(&mut temp[..read]); result.extend_from_slice(&temp[..read]); } Ok(result.freeze()) } + + /// Read into internal buffer and return decrypted bytes + /// + /// Useful when you need the data as Bytes rather than copying to a slice. 
+ pub async fn read_decrypt(&mut self, max_size: usize) -> Result { + use tokio::io::AsyncReadExt; + + if self.is_poisoned() { + return Err(self.take_poison_error()); + } + + // First check if we have buffered data + if let CryptoReaderState::Yielding { buffer } = &mut self.state { + let to_take = buffer.remaining().min(max_size); + let mut temp = vec![0u8; to_take]; + buffer.copy_to(&mut temp); + + if buffer.is_empty() { + self.state = CryptoReaderState::Idle; + } + + return Ok(Bytes::from(temp)); + } + + // Read from upstream + let mut temp = vec![0u8; max_size]; + let read = self.read(&mut temp).await?; + + if read == 0 { + return Ok(Bytes::new()); + } + + temp.truncate(read); + Ok(Bytes::from(temp)) + } } -/// Writer that encrypts data using AES-CTR +// ============= CryptoWriter State ============= + +/// State machine states for CryptoWriter +#[derive(Debug)] +enum CryptoWriterState { + /// Ready to accept new data + Idle, + + /// Have pending encrypted data to flush + Flushing { + /// Buffer of encrypted data waiting to be written + pending: WriteBuffer, + }, + + /// Stream encountered an error and cannot be used + Poisoned { + /// The error that caused poisoning + error: Option, + }, +} + +impl StreamState for CryptoWriterState { + fn is_terminal(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn is_poisoned(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn state_name(&self) -> &'static str { + match self { + Self::Idle => "Idle", + Self::Flushing { .. } => "Flushing", + Self::Poisoned { .. 
} => "Poisoned", + } + } +} + +// ============= CryptoWriter ============= + +/// Writer that encrypts data using AES-CTR with proper state machine +/// +/// This writer handles partial writes correctly by: +/// - Maintaining internal state for pending data +/// - Returning honest byte counts (only what's actually written or safely buffered) +/// - Implementing backpressure when internal buffer is full +/// +/// # State Machine +/// +/// ┌──────────┐ write ┌──────────┐ +/// │ Idle │ ----------> │ Flushing │ +/// │ │ <---------- │ │ +/// └──────────┘ flushed └──────────┘ +/// │ │ +/// │ errors │ +/// ┌───────────────────────────────────┐ +/// │ Poisoned │ +/// └───────────────────────────────────┘ +/// +/// # Backpressure +/// +/// When the internal pending buffer exceeds `MAX_PENDING_WRITE`, the writer +/// will return `Poll::Pending` until some data has been flushed to upstream. pub struct CryptoWriter { + /// Upstream writer upstream: W, + /// AES-CTR encryptor encryptor: AesCtr, - pending: BytesMut, + /// Current state + state: CryptoWriterState, } impl CryptoWriter { + /// Create new crypto writer pub fn new(upstream: W, encryptor: AesCtr) -> Self { Self { upstream, encryptor, - pending: BytesMut::with_capacity(8192), + state: CryptoWriterState::Idle, } } + /// Get reference to upstream pub fn get_ref(&self) -> &W { &self.upstream } + /// Get mutable reference to upstream pub fn get_mut(&mut self) -> &mut W { &mut self.upstream } + /// Consume and return upstream pub fn into_inner(self) -> W { self.upstream } + + /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { + self.state.is_poisoned() + } + + /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { + self.state.state_name() + } + + /// Check if there's pending data to flush + pub fn has_pending(&self) -> bool { + matches!(&self.state, CryptoWriterState::Flushing { pending } if !pending.is_empty()) + } + + /// Get pending bytes count + pub fn 
pending_len(&self) -> usize { + match &self.state { + CryptoWriterState::Flushing { pending } => pending.len(), + _ => 0, + } + } + + /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { + self.state = CryptoWriterState::Poisoned { error: Some(error) }; + } + + /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { + match &mut self.state { + CryptoWriterState::Poisoned { error } => { + error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }) + } + _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), + } + } +} + +impl CryptoWriter { + /// Try to flush pending data to upstream + /// + /// Returns: + /// - `Poll::Ready(Ok(true))` if all pending data was flushed + /// - `Poll::Ready(Ok(false))` if some data remains + /// - `Poll::Pending` if upstream would block + /// - `Poll::Ready(Err(_))` on error + fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match &mut self.state { + CryptoWriterState::Idle => { + return Poll::Ready(Ok(true)); + } + + CryptoWriterState::Poisoned { .. 
} => { + let err = self.take_poison_error(); + return Poll::Ready(Err(err)); + } + + CryptoWriterState::Flushing { pending } => { + if pending.is_empty() { + self.state = CryptoWriterState::Idle; + return Poll::Ready(Ok(true)); + } + + let data = pending.pending(); + match Pin::new(&mut self.upstream).poll_write(cx, data) { + Poll::Pending => return Poll::Pending, + + Poll::Ready(Err(e)) => { + self.poison(io::Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + + Poll::Ready(Ok(0)) => { + let err = io::Error::new( + ErrorKind::WriteZero, + "upstream returned 0 bytes written" + ); + self.poison(err.into()); + return Poll::Ready(Err(io::Error::new( + ErrorKind::WriteZero, + "upstream returned 0 bytes written" + ))); + } + + Poll::Ready(Ok(n)) => { + pending.advance(n); + // Continue loop to check if fully flushed + } + } + } + } + } + } } impl AsyncWrite for CryptoWriter { @@ -142,101 +511,183 @@ impl AsyncWrite for CryptoWriter { ) -> Poll> { let this = self.get_mut(); - if !this.pending.is_empty() { - match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { - Poll::Ready(Ok(written)) => { - let _ = this.pending.split_to(written); - - if !this.pending.is_empty() { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } + // Check for poisoned state + if let CryptoWriterState::Poisoned { .. 
} = &this.state { + let err = this.take_poison_error(); + return Poll::Ready(Err(err)); } - // Pending Null + // Empty write is always successful if buf.is_empty() { return Poll::Ready(Ok(0)); } - // Encrypt + // First, try to flush any pending data + match this.poll_flush_pending(cx) { + Poll::Pending => { + // Check backpressure + if this.pending_len() >= MAX_PENDING_WRITE { + // Too much pending, must wait + return Poll::Pending; + } + // Can buffer more, continue below + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(_)) => { + // Flushed (possibly partially), continue + } + } + + // Encrypt the data let mut encrypted = buf.to_vec(); this.encryptor.apply(&mut encrypted); - // Write Try + // Try to write directly to upstream first match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) { - Poll::Ready(Ok(written)) => { - if written < encrypted.len() { - // Partial write — сохраняем остаток в pending - this.pending.extend_from_slice(&encrypted[written..]); - } - Poll::Ready(Ok(buf.len())) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => { - this.pending.extend_from_slice(&encrypted); + Poll::Ready(Ok(n)) if n == encrypted.len() => { + // All data written directly Poll::Ready(Ok(buf.len())) } + + Poll::Ready(Ok(n)) => { + // Partial write - buffer the rest + let remaining = &encrypted[n..]; + + // Ensure we're in Flushing state + let pending = match &mut this.state { + CryptoWriterState::Flushing { pending } => pending, + CryptoWriterState::Idle => { + this.state = CryptoWriterState::Flushing { + pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), + }; + match &mut this.state { + CryptoWriterState::Flushing { pending } => pending, + _ => unreachable!(), + } + } + CryptoWriterState::Poisoned { .. 
} => unreachable!(), + }; + + // Try to buffer remaining + if pending.remaining_capacity() >= remaining.len() { + pending.extend(remaining).expect("capacity checked"); + Poll::Ready(Ok(buf.len())) + } else { + // Not enough buffer space - report what we could write + // The caller will need to retry with the rest + let bytes_accepted = n + pending.remaining_capacity(); + if bytes_accepted > n { + let can_buffer = &encrypted[n..bytes_accepted]; + pending.extend(can_buffer).expect("capacity checked"); + } + Poll::Ready(Ok(bytes_accepted.min(buf.len()))) + } + } + + Poll::Ready(Err(e)) => { + this.poison(io::Error::new(e.kind(), e.to_string())); + Poll::Ready(Err(e)) + } + + Poll::Pending => { + // Upstream would block - buffer the encrypted data + let pending = match &mut this.state { + CryptoWriterState::Flushing { pending } => pending, + CryptoWriterState::Idle => { + this.state = CryptoWriterState::Flushing { + pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), + }; + match &mut this.state { + CryptoWriterState::Flushing { pending } => pending, + _ => unreachable!(), + } + } + CryptoWriterState::Poisoned { .. 
} => unreachable!(), + }; + + // Check if we can buffer all + if pending.remaining_capacity() >= encrypted.len() { + pending.extend(&encrypted).expect("capacity checked"); + // Wake up to try flushing later + cx.waker().wake_by_ref(); + Poll::Ready(Ok(buf.len())) + } else if pending.remaining_capacity() > 0 { + // Partial buffer + let can_buffer = pending.remaining_capacity(); + pending.extend(&encrypted[..can_buffer]).expect("capacity checked"); + cx.waker().wake_by_ref(); + Poll::Ready(Ok(can_buffer)) + } else { + // No buffer space - backpressure + Poll::Pending + } + } } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - while !this.pending.is_empty() { - match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { - Poll::Ready(Ok(0)) => { - return Poll::Ready(Err(Error::new( - ErrorKind::WriteZero, - "Failed to write pending data during flush", - ))); - } - Poll::Ready(Ok(written)) => { - let _ = this.pending.split_to(written); - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, + // First flush our pending buffer + match this.poll_flush_pending(cx)? 
{ + Poll::Pending => return Poll::Pending, + Poll::Ready(false) => { + cx.waker().wake_by_ref(); + return Poll::Pending; } + Poll::Ready(true) => {} } + // Then flush upstream Pin::new(&mut this.upstream).poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - while !this.pending.is_empty() { - match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { - Poll::Ready(Ok(0)) => { - break; - } - Poll::Ready(Ok(written)) => { - let _ = this.pending.split_to(written); - } - Poll::Ready(Err(_)) => { - break; - } - Poll::Pending => return Poll::Pending, + // Try to flush pending data first (best effort) + match this.poll_flush_pending(cx) { + Poll::Pending => { + // Continue with shutdown anyway after registering waker } + Poll::Ready(Err(_)) => { + // Ignore flush errors during shutdown + } + Poll::Ready(Ok(_)) => {} } + // Shutdown upstream Pin::new(&mut this.upstream).poll_shutdown(cx) } } +// ============= PassthroughStream ============= + /// Passthrough stream for fast mode - no encryption/decryption +/// +/// Used when keys are set up so that client and Telegram use the same +/// encryption, allowing data to pass through without re-encryption. 
pub struct PassthroughStream { inner: S, } impl PassthroughStream { + /// Create new passthrough stream pub fn new(inner: S) -> Self { Self { inner } } + /// Get reference to inner stream + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Get mutable reference to inner stream + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consume and return inner stream pub fn into_inner(self) -> S { self.inner } @@ -270,19 +721,40 @@ impl AsyncWrite for PassthroughStream { } } +// ============= Tests ============= + #[cfg(test)] mod tests { use super::*; use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable}; - use tokio::io::duplex; + use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; - /// Mock writer + // ============= Test Helpers ============= + + fn noop_waker() -> Waker { + const VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(std::ptr::null(), &VTABLE), + |_| {}, + |_| {}, + |_| {}, + ); + unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } + } + + /// Mock writer that simulates partial writes struct PartialWriter { + /// Max bytes to accept per write chunk_size: usize, + /// Collected data data: Vec, + /// Number of writes performed write_count: usize, + /// If true, return Pending on first write attempt + first_pending: bool, + /// Track if first call happened + first_call: bool, } impl PartialWriter { @@ -291,16 +763,29 @@ mod tests { chunk_size, data: Vec::new(), write_count: 0, + first_pending: false, + first_call: true, } } + + fn with_first_pending(mut self) -> Self { + self.first_pending = true; + self + } } impl AsyncWrite for PartialWriter { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + if self.first_pending && self.first_call { + self.first_call = false; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + self.write_count += 1; let to_write = 
buf.len().min(self.chunk_size); self.data.extend_from_slice(&buf[..to_write]); @@ -316,37 +801,166 @@ mod tests { } } - fn noop_waker() -> Waker { - const VTABLE: RawWakerVTable = RawWakerVTable::new( - |_| RawWaker::new(std::ptr::null(), &VTABLE), - |_| {}, - |_| {}, - |_| {}, - ); - unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } + /// Mock reader that returns data in chunks + struct ChunkedReader { + data: VecDeque, + chunk_size: usize, } - #[test] - fn test_crypto_writer_partial_write_correctness() { + impl ChunkedReader { + fn new(data: &[u8], chunk_size: usize) -> Self { + Self { + data: data.iter().copied().collect(), + chunk_size, + } + } + } + + impl AsyncRead for ChunkedReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.data.is_empty() { + return Poll::Ready(Ok(())); + } + + let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining()); + for _ in 0..to_read { + if let Some(byte) = self.data.pop_front() { + buf.put_slice(&[byte]); + } + } + + Poll::Ready(Ok(())) + } + } + + // ============= CryptoReader Tests ============= + + #[tokio::test] + async fn test_crypto_reader_basic() { let key = [0x42u8; 32]; let iv = 12345u128; - // 10-byte Writer - let mock_writer = PartialWriter::new(10); + // Encrypt some data + let original = b"Hello, encrypted world!"; + let mut encryptor = AesCtr::new(&key, iv); + let encrypted = encryptor.encrypt(original); + + // Create reader + let reader = ChunkedReader::new(&encrypted, 100); + let decryptor = AesCtr::new(&key, iv); + let mut crypto_reader = CryptoReader::new(reader, decryptor); + + // Read and decrypt + let mut buf = vec![0u8; original.len()]; + crypto_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, original); + } + + #[tokio::test] + async fn test_crypto_reader_chunked() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + let original = b"This is a longer message that will be read in 
chunks"; + let mut encryptor = AesCtr::new(&key, iv); + let encrypted = encryptor.encrypt(original); + + // Read in very small chunks + let reader = ChunkedReader::new(&encrypted, 5); + let decryptor = AesCtr::new(&key, iv); + let mut crypto_reader = CryptoReader::new(reader, decryptor); + + let mut result = Vec::new(); + let mut buf = [0u8; 7]; // Read in chunks different from write chunks + + loop { + let n = crypto_reader.read(&mut buf).await.unwrap(); + if n == 0 { + break; + } + result.extend_from_slice(&buf[..n]); + } + + assert_eq!(&result, original); + } + + #[tokio::test] + async fn test_crypto_reader_read_exact_decrypt() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + let original = b"Exact read test data!"; + let mut encryptor = AesCtr::new(&key, iv); + let encrypted = encryptor.encrypt(original); + + let reader = ChunkedReader::new(&encrypted, 3); // Small chunks + let decryptor = AesCtr::new(&key, iv); + let mut crypto_reader = CryptoReader::new(reader, decryptor); + + let result = crypto_reader.read_exact_decrypt(original.len()).await.unwrap(); + assert_eq!(&result[..], original); + } + + // ============= CryptoWriter Tests ============= + + #[test] + fn test_crypto_writer_basic_sync() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + let mock_writer = PartialWriter::new(100); let encryptor = AesCtr::new(&key, iv); let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); let waker = noop_waker(); let mut cx = Context::from_waker(&waker); - // 25 byte - let original = b"Hello, this is test data!"; + let original = b"Hello, world!"; - // First Write + // Write + let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); + assert!(matches!(result, Poll::Ready(Ok(13)))); + + // Verify encryption happened + let encrypted = &crypto_writer.upstream.data; + assert_eq!(encrypted.len(), original.len()); + assert_ne!(encrypted.as_slice(), original); // Should be encrypted + + // Decrypt and verify + let mut decryptor = 
AesCtr::new(&key, iv); + let mut decrypted = encrypted.clone(); + decryptor.apply(&mut decrypted); + assert_eq!(&decrypted, original); + } + + #[test] + fn test_crypto_writer_partial_write() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + // Writer that only accepts 5 bytes at a time + let mock_writer = PartialWriter::new(5); + let encryptor = AesCtr::new(&key, iv); + let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); + + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let original = b"This is a longer message!"; // 25 bytes + + // First write - should accept all 25 bytes (5 written, 20 buffered) let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); assert!(matches!(result, Poll::Ready(Ok(25)))); - // Flush before continue Pending + // Should have pending data + assert!(crypto_writer.has_pending()); + + // Flush to drain pending loop { match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { Poll::Ready(Ok(())) => break, @@ -355,56 +969,48 @@ mod tests { } } - // Write Check - let encrypted = &crypto_writer.upstream.data; - assert_eq!(encrypted.len(), 25); + // All data should be written now + assert!(!crypto_writer.has_pending()); + assert_eq!(crypto_writer.upstream.data.len(), 25); - // Decrypt + Verify + // Verify decryption let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = encrypted.clone(); + let mut decrypted = crypto_writer.upstream.data.clone(); decryptor.apply(&mut decrypted); - assert_eq!(&decrypted, original); } #[test] - fn test_crypto_writer_multiple_partial_writes() { - let key = [0xAB; 32]; - let iv = 9999u128; - - let mock_writer = PartialWriter::new(3); + fn test_crypto_writer_pending_on_first_write() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + // Writer that returns Pending on first call + let mock_writer = PartialWriter::new(100).with_first_pending(); let encryptor = AesCtr::new(&key, iv); let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); 
let waker = noop_waker(); let mut cx = Context::from_waker(&waker); - let data1 = b"First"; - let data2 = b"Second"; - let data3 = b"Third"; + let original = b"Test data"; - Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap(); - // Flush - while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} + // First write should buffer and return Ready (not Pending) + // because we have buffer space + let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); + assert!(matches!(result, Poll::Ready(Ok(9)))); - Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap(); - while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} + // Data should be buffered + assert!(crypto_writer.has_pending()); - Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap(); - while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} - - // Assemble - let mut expected = Vec::new(); - expected.extend_from_slice(data1); - expected.extend_from_slice(data2); - expected.extend_from_slice(data3); - - // Decrypt - let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = crypto_writer.upstream.data.clone(); - decryptor.apply(&mut decrypted); - - assert_eq!(decrypted, expected); + // Second poll_flush should succeed + loop { + match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { + Poll::Ready(Ok(())) => break, + Poll::Ready(Err(e)) => panic!("Flush error: {}", e), + Poll::Pending => continue, + } + } } #[tokio::test] @@ -445,10 +1051,10 @@ mod tests { let mut writer = CryptoWriter::new(client, encryptor); let mut reader = CryptoReader::new(server, decryptor); - // Hugeload + // Large data let original: Vec = (0..10000).map(|i| (i % 256) as u8).collect(); - // Write + // Write in background let write_data = original.clone(); let write_handle = tokio::spawn(async move { writer.write_all(&write_data).await.unwrap(); @@ -471,4 +1077,77 @@ mod tests { assert_eq!(received, original); } + + #[tokio::test] + async fn 
test_crypto_writer_backpressure() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + // Very small buffer duplex + let (client, _server) = duplex(64); + + let encryptor = AesCtr::new(&key, iv); + let mut writer = CryptoWriter::new(client, encryptor); + + // Try to write a lot of data + let large_data = vec![0u8; MAX_PENDING_WRITE + 1000]; + + // This should eventually block due to backpressure + // (duplex buffer full + our pending buffer full) + let write_result = tokio::time::timeout( + std::time::Duration::from_millis(100), + writer.write_all(&large_data) + ).await; + + // Should timeout because we can't write all data + assert!(write_result.is_err()); + } + + // ============= State Tests ============= + + #[test] + fn test_reader_state_transitions() { + let key = [0u8; 32]; + let iv = 0u128; + + let reader = ChunkedReader::new(&[], 10); + let decryptor = AesCtr::new(&key, iv); + let reader = CryptoReader::new(reader, decryptor); + + assert_eq!(reader.state_name(), "Idle"); + assert!(!reader.is_poisoned()); + } + + #[test] + fn test_writer_state_transitions() { + let key = [0u8; 32]; + let iv = 0u128; + + let writer = PartialWriter::new(10); + let encryptor = AesCtr::new(&key, iv); + let writer = CryptoWriter::new(writer, encryptor); + + assert_eq!(writer.state_name(), "Idle"); + assert!(!writer.is_poisoned()); + assert!(!writer.has_pending()); + } + + // ============= Passthrough Tests ============= + + #[tokio::test] + async fn test_passthrough_stream() { + let (client, server) = duplex(4096); + + let mut writer = PassthroughStream::new(client); + let mut reader = PassthroughStream::new(server); + + let data = b"No encryption here!"; + writer.write_all(data).await.unwrap(); + writer.flush().await.unwrap(); + + let mut buf = vec![0u8; data.len()]; + reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, data); + } } \ No newline at end of file diff --git a/src/stream/frame.rs b/src/stream/frame.rs new file mode 100644 index 0000000..32270eb --- 
/dev/null +++ b/src/stream/frame.rs @@ -0,0 +1,187 @@ +//! MTProto frame types and traits +//! +//! This module defines the common types and traits used by all +//! frame encoding/decoding implementations. + +use bytes::{Bytes, BytesMut}; +use std::io::Result; + +use crate::protocol::constants::ProtoTag; + +// ============= Frame Types ============= + +/// A decoded MTProto frame +#[derive(Debug, Clone)] +pub struct Frame { + /// Frame payload data + pub data: Bytes, + /// Frame metadata + pub meta: FrameMeta, +} + +impl Frame { + /// Create a new frame with data and default metadata + pub fn new(data: Bytes) -> Self { + Self { + data, + meta: FrameMeta::default(), + } + } + + /// Create a new frame with data and metadata + pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self { + Self { data, meta } + } + + /// Create an empty frame + pub fn empty() -> Self { + Self::new(Bytes::new()) + } + + /// Check if frame is empty + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Get frame length + pub fn len(&self) -> usize { + self.data.len() + } + + /// Create a QuickAck request frame + pub fn quickack(data: Bytes) -> Self { + Self { + data, + meta: FrameMeta { + quickack: true, + ..Default::default() + }, + } + } + + /// Create a simple ACK frame + pub fn simple_ack(data: Bytes) -> Self { + Self { + data, + meta: FrameMeta { + simple_ack: true, + ..Default::default() + }, + } + } +} + +/// Frame metadata +#[derive(Debug, Clone, Default)] +pub struct FrameMeta { + /// Quick ACK requested - client wants immediate acknowledgment + pub quickack: bool, + /// This is a simple ACK message (reversed data) + pub simple_ack: bool, + /// Original padding length (for secure mode) + pub padding_len: u8, +} + +impl FrameMeta { + /// Create new empty metadata + pub fn new() -> Self { + Self::default() + } + + /// Create with quickack flag + pub fn with_quickack(mut self) -> Self { + self.quickack = true; + self + } + + /// Create with simple_ack flag + pub fn 
with_simple_ack(mut self) -> Self { + self.simple_ack = true; + self + } + + /// Create with padding length + pub fn with_padding(mut self, len: u8) -> Self { + self.padding_len = len; + self + } + + /// Check if any special flags are set + pub fn has_flags(&self) -> bool { + self.quickack || self.simple_ack + } +} + +// ============= Codec Trait ============= + +/// Trait for frame codecs that can encode and decode frames +pub trait FrameCodec: Send + Sync { + /// Get the protocol tag for this codec + fn proto_tag(&self) -> ProtoTag; + + /// Encode a frame into the destination buffer + /// + /// Returns the number of bytes written. + fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result; + + /// Try to decode a frame from the source buffer + /// + /// Returns: + /// - `Ok(Some(frame))` if a complete frame was decoded + /// - `Ok(None)` if more data is needed + /// - `Err(e)` if an error occurred + /// + /// On success, the consumed bytes are removed from `src`. + fn decode(&self, src: &mut BytesMut) -> Result>; + + /// Get the minimum bytes needed to determine frame length + fn min_header_size(&self) -> usize; + + /// Get the maximum allowed frame size + fn max_frame_size(&self) -> usize { + // Default: 16MB + 16 * 1024 * 1024 + } +} + +// ============= Codec Factory ============= + +/// Create a frame codec for the given protocol tag +pub fn create_codec(proto_tag: ProtoTag) -> Box { + match proto_tag { + ProtoTag::Abridged => Box::new(super::frame_codec::AbridgedCodec::new()), + ProtoTag::Intermediate => Box::new(super::frame_codec::IntermediateCodec::new()), + ProtoTag::Secure => Box::new(super::frame_codec::SecureCodec::new()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_frame_creation() { + let frame = Frame::new(Bytes::from_static(b"test")); + assert_eq!(frame.len(), 4); + assert!(!frame.is_empty()); + assert!(!frame.meta.quickack); + + let frame = Frame::empty(); + assert!(frame.is_empty()); + + let frame = 
Frame::quickack(Bytes::from_static(b"ack")); + assert!(frame.meta.quickack); + } + + #[test] + fn test_frame_meta() { + let meta = FrameMeta::new() + .with_quickack() + .with_padding(3); + + assert!(meta.quickack); + assert!(!meta.simple_ack); + assert_eq!(meta.padding_len, 3); + assert!(meta.has_flags()); + } +} \ No newline at end of file diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs new file mode 100644 index 0000000..75f6bde --- /dev/null +++ b/src/stream/frame_codec.rs @@ -0,0 +1,621 @@ +//! tokio-util codec integration for MTProto frames +//! +//! This module provides Encoder/Decoder implementations compatible +//! with tokio-util's Framed wrapper for easy async frame I/O. + +use bytes::{Bytes, BytesMut, BufMut}; +use std::io::{self, Error, ErrorKind}; +use tokio_util::codec::{Decoder, Encoder}; + +use crate::protocol::constants::ProtoTag; +use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; + +// ============= Unified Codec ============= + +/// Unified frame codec that wraps all protocol variants +/// +/// This codec implements tokio-util's Encoder and Decoder traits, +/// allowing it to be used with `Framed` for async frame I/O. 
+pub struct FrameCodec { + /// Protocol variant + proto_tag: ProtoTag, + /// Maximum allowed frame size + max_frame_size: usize, +} + +impl FrameCodec { + /// Create a new codec for the given protocol + pub fn new(proto_tag: ProtoTag) -> Self { + Self { + proto_tag, + max_frame_size: 16 * 1024 * 1024, // 16MB default + } + } + + /// Set maximum frame size + pub fn with_max_frame_size(mut self, size: usize) -> Self { + self.max_frame_size = size; + self + } + + /// Get protocol tag + pub fn proto_tag(&self) -> ProtoTag { + self.proto_tag + } +} + +impl Decoder for FrameCodec { + type Item = Frame; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match self.proto_tag { + ProtoTag::Abridged => decode_abridged(src, self.max_frame_size), + ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size), + ProtoTag::Secure => decode_secure(src, self.max_frame_size), + } + } +} + +impl Encoder for FrameCodec { + type Error = io::Error; + + fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + match self.proto_tag { + ProtoTag::Abridged => encode_abridged(&frame, dst), + ProtoTag::Intermediate => encode_intermediate(&frame, dst), + ProtoTag::Secure => encode_secure(&frame, dst), + } + } +} + +// ============= Abridged Protocol ============= + +fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result> { + if src.is_empty() { + return Ok(None); + } + + let mut meta = FrameMeta::new(); + let first_byte = src[0]; + + // Extract length and quickack flag + let mut len_words = (first_byte & 0x7f) as usize; + if first_byte >= 0x80 { + meta.quickack = true; + } + + let header_len; + + if len_words == 0x7f { + // Extended length (3 more bytes needed) + if src.len() < 4 { + return Ok(None); + } + len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize; + header_len = 4; + } else { + header_len = 1; + } + + // Length is in 4-byte words + let byte_len = 
len_words.checked_mul(4).ok_or_else(|| { + Error::new(ErrorKind::InvalidData, "frame length overflow") + })?; + + // Validate size + if byte_len > max_size { + return Err(Error::new( + ErrorKind::InvalidData, + format!("frame too large: {} bytes (max {})", byte_len, max_size) + )); + } + + let total_len = header_len + byte_len; + + if src.len() < total_len { + // Reserve space for the rest of the frame + src.reserve(total_len - src.len()); + return Ok(None); + } + + // Extract data + let _ = src.split_to(header_len); + let data = src.split_to(byte_len).freeze(); + + Ok(Some(Frame::with_meta(data, meta))) +} + +fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { + let data = &frame.data; + + // Validate alignment + if data.len() % 4 != 0 { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("abridged frame must be 4-byte aligned, got {} bytes", data.len()) + )); + } + + // Simple ACK: send reversed data without header + if frame.meta.simple_ack { + dst.reserve(data.len()); + for byte in data.iter().rev() { + dst.put_u8(*byte); + } + return Ok(()); + } + + let len_words = data.len() / 4; + + if len_words < 0x7f { + // Short header + dst.reserve(1 + data.len()); + let mut len_byte = len_words as u8; + if frame.meta.quickack { + len_byte |= 0x80; + } + dst.put_u8(len_byte); + } else if len_words < (1 << 24) { + // Extended header + dst.reserve(4 + data.len()); + let mut first = 0x7fu8; + if frame.meta.quickack { + first |= 0x80; + } + dst.put_u8(first); + let len_bytes = (len_words as u32).to_le_bytes(); + dst.extend_from_slice(&len_bytes[..3]); + } else { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("frame too large: {} bytes", data.len()) + )); + } + + dst.extend_from_slice(data); + Ok(()) +} + +// ============= Intermediate Protocol ============= + +fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result> { + if src.len() < 4 { + return Ok(None); + } + + let mut meta = FrameMeta::new(); + let mut len 
= u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize; + + // Check QuickACK flag + if len >= 0x80000000 { + meta.quickack = true; + len -= 0x80000000; + } + + // Validate size + if len > max_size { + return Err(Error::new( + ErrorKind::InvalidData, + format!("frame too large: {} bytes (max {})", len, max_size) + )); + } + + let total_len = 4 + len; + + if src.len() < total_len { + src.reserve(total_len - src.len()); + return Ok(None); + } + + // Extract data + let _ = src.split_to(4); + let data = src.split_to(len).freeze(); + + Ok(Some(Frame::with_meta(data, meta))) +} + +fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { + let data = &frame.data; + + // Simple ACK: just send data + if frame.meta.simple_ack { + dst.reserve(data.len()); + dst.extend_from_slice(data); + return Ok(()); + } + + dst.reserve(4 + data.len()); + + let mut len = data.len() as u32; + if frame.meta.quickack { + len |= 0x80000000; + } + + dst.extend_from_slice(&len.to_le_bytes()); + dst.extend_from_slice(data); + + Ok(()) +} + +// ============= Secure Intermediate Protocol ============= + +fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result> { + if src.len() < 4 { + return Ok(None); + } + + let mut meta = FrameMeta::new(); + let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize; + + // Check QuickACK flag + if len >= 0x80000000 { + meta.quickack = true; + len -= 0x80000000; + } + + // Validate size + if len > max_size { + return Err(Error::new( + ErrorKind::InvalidData, + format!("frame too large: {} bytes (max {})", len, max_size) + )); + } + + let total_len = 4 + len; + + if src.len() < total_len { + src.reserve(total_len - src.len()); + return Ok(None); + } + + // Calculate padding (indicated by length not divisible by 4) + let padding_len = len % 4; + let data_len = if padding_len != 0 { + len - padding_len + } else { + len + }; + + meta.padding_len = padding_len as u8; + + // Extract data (excluding padding) + 
let _ = src.split_to(4); + let all_data = src.split_to(len); + // Copy only the data portion, excluding padding + let data = Bytes::copy_from_slice(&all_data[..data_len]); + + Ok(Some(Frame::with_meta(data, meta))) +} + +fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { + use crate::crypto::random::SECURE_RANDOM; + + let data = &frame.data; + + // Simple ACK: just send data + if frame.meta.simple_ack { + dst.reserve(data.len()); + dst.extend_from_slice(data); + return Ok(()); + } + + // Generate padding to make length not divisible by 4 + let padding_len = if data.len() % 4 == 0 { + // Add 1-3 bytes to make it non-aligned + (SECURE_RANDOM.range(3) + 1) as usize + } else { + // Already non-aligned, can add 0-3 + SECURE_RANDOM.range(4) as usize + }; + + let total_len = data.len() + padding_len; + dst.reserve(4 + total_len); + + let mut len = total_len as u32; + if frame.meta.quickack { + len |= 0x80000000; + } + + dst.extend_from_slice(&len.to_le_bytes()); + dst.extend_from_slice(data); + + if padding_len > 0 { + let padding = SECURE_RANDOM.bytes(padding_len); + dst.extend_from_slice(&padding); + } + + Ok(()) +} + +// ============= Typed Codecs ============= + +/// Abridged protocol codec +pub struct AbridgedCodec { + max_frame_size: usize, +} + +impl AbridgedCodec { + pub fn new() -> Self { + Self { + max_frame_size: 16 * 1024 * 1024, + } + } +} + +impl Default for AbridgedCodec { + fn default() -> Self { + Self::new() + } +} + +impl Decoder for AbridgedCodec { + type Item = Frame; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decode_abridged(src, self.max_frame_size) + } +} + +impl Encoder for AbridgedCodec { + type Error = io::Error; + + fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + encode_abridged(&frame, dst) + } +} + +impl FrameCodecTrait for AbridgedCodec { + fn proto_tag(&self) -> ProtoTag { + ProtoTag::Abridged + } + + fn encode(&self, frame: 
&Frame, dst: &mut BytesMut) -> io::Result { + let before = dst.len(); + encode_abridged(frame, dst)?; + Ok(dst.len() - before) + } + + fn decode(&self, src: &mut BytesMut) -> io::Result> { + decode_abridged(src, self.max_frame_size) + } + + fn min_header_size(&self) -> usize { + 1 + } +} + +/// Intermediate protocol codec +pub struct IntermediateCodec { + max_frame_size: usize, +} + +impl IntermediateCodec { + pub fn new() -> Self { + Self { + max_frame_size: 16 * 1024 * 1024, + } + } +} + +impl Default for IntermediateCodec { + fn default() -> Self { + Self::new() + } +} + +impl Decoder for IntermediateCodec { + type Item = Frame; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decode_intermediate(src, self.max_frame_size) + } +} + +impl Encoder for IntermediateCodec { + type Error = io::Error; + + fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + encode_intermediate(&frame, dst) + } +} + +impl FrameCodecTrait for IntermediateCodec { + fn proto_tag(&self) -> ProtoTag { + ProtoTag::Intermediate + } + + fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { + let before = dst.len(); + encode_intermediate(frame, dst)?; + Ok(dst.len() - before) + } + + fn decode(&self, src: &mut BytesMut) -> io::Result> { + decode_intermediate(src, self.max_frame_size) + } + + fn min_header_size(&self) -> usize { + 4 + } +} + +/// Secure Intermediate protocol codec +pub struct SecureCodec { + max_frame_size: usize, +} + +impl SecureCodec { + pub fn new() -> Self { + Self { + max_frame_size: 16 * 1024 * 1024, + } + } +} + +impl Default for SecureCodec { + fn default() -> Self { + Self::new() + } +} + +impl Decoder for SecureCodec { + type Item = Frame; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decode_secure(src, self.max_frame_size) + } +} + +impl Encoder for SecureCodec { + type Error = io::Error; + + fn encode(&mut self, frame: 
Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + encode_secure(&frame, dst) + } +} + +impl FrameCodecTrait for SecureCodec { + fn proto_tag(&self) -> ProtoTag { + ProtoTag::Secure + } + + fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { + let before = dst.len(); + encode_secure(frame, dst)?; + Ok(dst.len() - before) + } + + fn decode(&self, src: &mut BytesMut) -> io::Result> { + decode_secure(src, self.max_frame_size) + } + + fn min_header_size(&self) -> usize { + 4 + } +} + +// ============= Tests ============= + +#[cfg(test)] +mod tests { + use super::*; + use tokio_util::codec::{FramedRead, FramedWrite}; + use tokio::io::duplex; + use futures::{SinkExt, StreamExt}; + + #[tokio::test] + async fn test_framed_abridged() { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, AbridgedCodec::new()); + let mut reader = FramedRead::new(server, AbridgedCodec::new()); + + // Write a frame + let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8])); + writer.send(frame).await.unwrap(); + + // Read it back + let received = reader.next().await.unwrap().unwrap(); + assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]); + } + + #[tokio::test] + async fn test_framed_intermediate() { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, IntermediateCodec::new()); + let mut reader = FramedRead::new(server, IntermediateCodec::new()); + + let frame = Frame::new(Bytes::from_static(b"hello world")); + writer.send(frame).await.unwrap(); + + let received = reader.next().await.unwrap().unwrap(); + assert_eq!(&received.data[..], b"hello world"); + } + + #[tokio::test] + async fn test_framed_secure() { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, SecureCodec::new()); + let mut reader = FramedRead::new(server, SecureCodec::new()); + + let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); + let frame = Frame::new(original.clone()); + 
writer.send(frame).await.unwrap(); + + let received = reader.next().await.unwrap().unwrap(); + assert_eq!(&received.data[..], &original[..]); + } + + #[tokio::test] + async fn test_unified_codec() { + for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag)); + let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag)); + + // Use 4-byte aligned data for abridged compatibility + let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); + let frame = Frame::new(original.clone()); + writer.send(frame).await.unwrap(); + + let received = reader.next().await.unwrap().unwrap(); + assert_eq!(received.data.len(), 8); + } + } + + #[tokio::test] + async fn test_multiple_frames() { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, IntermediateCodec::new()); + let mut reader = FramedRead::new(server, IntermediateCodec::new()); + + // Send multiple frames + for i in 0..10 { + let data: Vec = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect(); + let frame = Frame::new(Bytes::from(data)); + writer.send(frame).await.unwrap(); + } + + // Receive them + for i in 0..10 { + let received = reader.next().await.unwrap().unwrap(); + assert_eq!(received.data.len(), (i + 1) * 10); + } + } + + #[tokio::test] + async fn test_quickack_flag() { + let (client, server) = duplex(4096); + + let mut writer = FramedWrite::new(client, IntermediateCodec::new()); + let mut reader = FramedRead::new(server, IntermediateCodec::new()); + + let frame = Frame::quickack(Bytes::from_static(b"urgent")); + writer.send(frame).await.unwrap(); + + let received = reader.next().await.unwrap().unwrap(); + assert!(received.meta.quickack); + } + + #[test] + fn test_frame_too_large() { + let mut codec = FrameCodec::new(ProtoTag::Intermediate) + .with_max_frame_size(100); + + // Create a "frame" that claims to be very large + let mut 
buf = BytesMut::new();
        buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
        buf.extend_from_slice(&[0u8; 10]); // partial data

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
    }
}
\ No newline at end of file
diff --git a/src/stream/mod.rs b/src/stream/mod.rs
index 2f5e545..a86b56f 100644
--- a/src/stream/mod.rs
+++ b/src/stream/mod.rs
@@ -5,6 +5,10 @@ pub mod buffer_pool;
 pub mod traits;
 pub mod crypto_stream;
 pub mod tls_stream;
+pub mod frame;
+pub mod frame_codec;
+
+// Legacy compatibility - will be removed later
 pub mod frame_stream;
 
 // Re-export state machine types
@@ -19,4 +23,21 @@ pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
 // Re-export stream implementations
 pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
 pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
-pub use frame_stream::*;
\ No newline at end of file
+
+// Re-export frame types
+pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
+
+// Re-export tokio-util compatible codecs
+pub use frame_codec::{
+    FrameCodec,
+    AbridgedCodec, IntermediateCodec, SecureCodec,
+};
+
+// Legacy re-exports for compatibility
+pub use frame_stream::{
+    AbridgedFrameReader, AbridgedFrameWriter,
+    IntermediateFrameReader, IntermediateFrameWriter,
+    SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
+    MtprotoFrameReader, MtprotoFrameWriter,
+    FrameReaderKind, FrameWriterKind,
+};
\ No newline at end of file