diff --git a/Cargo.toml b/Cargo.toml
index b645091..0f3384f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,6 +46,7 @@ base64 = "0.21"
 url = "2.5"
 regex = "1.10"
 once_cell = "1.19"
+crossbeam-queue = "0.3"
 
 # HTTP
 reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs
index b5651b1..592a21f 100644
--- a/src/crypto/aes.rs
+++ b/src/crypto/aes.rs
@@ -1,21 +1,24 @@
-//! AES
+//! AES encryption implementations
+//!
+//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
 
 use aes::Aes256;
 use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
-use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
-use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
 use crate::error::{ProxyError, Result};
 
 type Aes256Ctr = Ctr128BE<Aes256>;
-type Aes256CbcEnc = CbcEncryptor<Aes256>;
-type Aes256CbcDec = CbcDecryptor<Aes256>;
+
+// ============= AES-256-CTR =============
 
 /// AES-256-CTR encryptor/decryptor
+///
+/// CTR mode is symmetric - encryption and decryption are the same operation.
 pub struct AesCtr {
     cipher: Aes256Ctr,
 }
 
 impl AesCtr {
+    /// Create new AES-CTR cipher with key and IV
     pub fn new(key: &[u8; 32], iv: u128) -> Self {
         let iv_bytes = iv.to_be_bytes();
         Self {
@@ -23,6 +26,7 @@ impl AesCtr {
         }
     }
 
+    /// Create from key and IV slices
     pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
         if key.len() != 32 {
             return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
@@ -54,17 +58,28 @@ impl AesCtr {
     }
 }
 
-/// AES-256-CBC Ciphermagic
+// ============= AES-256-CBC =============
+
+/// AES-256-CBC cipher with proper chaining
+///
+/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
+/// are different operations. This implementation handles CBC chaining
+/// correctly across multiple blocks.
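+///
+/// # Example
+///
+/// A minimal round-trip sketch; the all-zero key, IV, and plaintext are
+/// illustrative values only:
+///
+/// ```ignore
+/// let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
+/// let plaintext = [0u8; 32]; // length must be a multiple of 16
+/// let ciphertext = cipher.encrypt(&plaintext)?;
+/// let decrypted = cipher.decrypt(&ciphertext)?;
+/// assert_eq!(decrypted.as_slice(), plaintext.as_slice());
+/// ```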
pub struct AesCbc { key: [u8; 32], iv: [u8; 16], } impl AesCbc { + /// AES block size + const BLOCK_SIZE: usize = 16; + + /// Create new AES-CBC cipher with key and IV pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { Self { key, iv } } + /// Create from slices pub fn from_slices(key: &[u8], iv: &[u8]) -> Result { if key.len() != 32 { return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); @@ -79,32 +94,36 @@ impl AesCbc { }) } - /// Encrypt data using CBC mode - pub fn encrypt(&self, data: &[u8]) -> Result> { - if data.len() % 16 != 0 { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); - } - - if data.is_empty() { - return Ok(Vec::new()); - } - - let mut buffer = data.to_vec(); - - let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into()); - - for chunk in buffer.chunks_mut(16) { - encryptor.encrypt_block_mut(chunk.into()); - } - - Ok(buffer) + /// Encrypt a single block using raw AES (no chaining) + fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { + use aes::cipher::BlockEncrypt; + let mut output = *block; + key_schedule.encrypt_block((&mut output).into()); + output } - /// Decrypt data using CBC mode - pub fn decrypt(&self, data: &[u8]) -> Result> { - if data.len() % 16 != 0 { + /// Decrypt a single block using raw AES (no chaining) + fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { + use aes::cipher::BlockDecrypt; + let mut output = *block; + key_schedule.decrypt_block((&mut output).into()); + output + } + + /// XOR two 16-byte blocks + fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = a[i] ^ b[i]; + } + result + } + + /// Encrypt data using CBC mode with proper chaining + /// + /// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV + pub fn encrypt(&self, data: &[u8]) -> Result> { + if data.len() % Self::BLOCK_SIZE != 0 { return Err(ProxyError::Crypto( format!("CBC data must be aligned to 16 bytes, got {}", data.len()) )); @@ -114,20 +133,73 @@ impl AesCbc { return Ok(Vec::new()); } - let mut buffer = data.to_vec(); + use aes::cipher::KeyInit; + let key_schedule = aes::Aes256::new((&self.key).into()); - let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); + let mut result = Vec::with_capacity(data.len()); + let mut prev_ciphertext = self.iv; - for chunk in buffer.chunks_mut(16) { - decryptor.decrypt_block_mut(chunk.into()); + for chunk in data.chunks(Self::BLOCK_SIZE) { + let plaintext: [u8; 16] = chunk.try_into().unwrap(); + + // XOR plaintext with previous ciphertext (or IV for first block) + let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); + + // Encrypt the XORed block + let ciphertext = self.encrypt_block(&xored, &key_schedule); + + // Save for next iteration + prev_ciphertext = ciphertext; + + // Append to result + result.extend_from_slice(&ciphertext); } - Ok(buffer) + Ok(result) + } + + /// Decrypt data using CBC mode with proper chaining + /// + /// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV + pub fn decrypt(&self, data: &[u8]) -> Result> { + if data.len() % Self::BLOCK_SIZE != 0 { + return Err(ProxyError::Crypto( + format!("CBC data must be aligned to 16 bytes, got {}", data.len()) + )); + } + + if data.is_empty() { + return Ok(Vec::new()); + } + + use aes::cipher::KeyInit; + let key_schedule = aes::Aes256::new((&self.key).into()); + + let mut result = 
Vec::with_capacity(data.len()); + let mut prev_ciphertext = self.iv; + + for chunk in data.chunks(Self::BLOCK_SIZE) { + let ciphertext: [u8; 16] = chunk.try_into().unwrap(); + + // Decrypt the block + let decrypted = self.decrypt_block(&ciphertext, &key_schedule); + + // XOR with previous ciphertext (or IV for first block) + let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext); + + // Save current ciphertext for next iteration + prev_ciphertext = ciphertext; + + // Append to result + result.extend_from_slice(&plaintext); + } + + Ok(result) } /// Encrypt data in-place pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { - if data.len() % 16 != 0 { + if data.len() % Self::BLOCK_SIZE != 0 { return Err(ProxyError::Crypto( format!("CBC data must be aligned to 16 bytes, got {}", data.len()) )); @@ -137,10 +209,25 @@ impl AesCbc { return Ok(()); } - let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into()); + use aes::cipher::KeyInit; + let key_schedule = aes::Aes256::new((&self.key).into()); - for chunk in data.chunks_mut(16) { - encryptor.encrypt_block_mut(chunk.into()); + let mut prev_ciphertext = self.iv; + + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { + let block = &mut data[i..i + Self::BLOCK_SIZE]; + + // XOR with previous ciphertext + for j in 0..Self::BLOCK_SIZE { + block[j] ^= prev_ciphertext[j]; + } + + // Encrypt in-place + let block_array: &mut [u8; 16] = block.try_into().unwrap(); + *block_array = self.encrypt_block(block_array, &key_schedule); + + // Save for next iteration + prev_ciphertext = *block_array; } Ok(()) @@ -148,7 +235,7 @@ impl AesCbc { /// Decrypt data in-place pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { - if data.len() % 16 != 0 { + if data.len() % Self::BLOCK_SIZE != 0 { return Err(ProxyError::Crypto( format!("CBC data must be aligned to 16 bytes, got {}", data.len()) )); @@ -158,16 +245,38 @@ impl AesCbc { return Ok(()); } - let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); + use aes::cipher::KeyInit; + let key_schedule = aes::Aes256::new((&self.key).into()); - for chunk in data.chunks_mut(16) { - decryptor.decrypt_block_mut(chunk.into()); + // For in-place decryption, we need to save ciphertext blocks + // before we overwrite them + let mut prev_ciphertext = self.iv; + + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { + let block = &mut data[i..i + Self::BLOCK_SIZE]; + + // Save current ciphertext before modifying + let current_ciphertext: [u8; 16] = block.try_into().unwrap(); + + // Decrypt in-place + let block_array: &mut [u8; 16] = block.try_into().unwrap(); + *block_array = self.decrypt_block(block_array, &key_schedule); + + // XOR with previous ciphertext + for j in 0..Self::BLOCK_SIZE { + block[j] ^= prev_ciphertext[j]; + } + + // Save for next iteration + prev_ciphertext = current_ciphertext; } Ok(()) } } +// ============= Encryption Traits ============= + /// Trait for unified encryption interface pub trait Encryptor: Send + Sync { fn encrypt(&mut self, data: &[u8]) -> Vec; @@ -209,6 +318,8 @@ impl Decryptor for PassthroughEncryptor { mod tests { use super::*; + // ============= AES-CTR Tests ============= + #[test] fn test_aes_ctr_roundtrip() { let key = [0u8; 32]; @@ -225,13 +336,35 @@ mod tests { assert_eq!(original.as_slice(), decrypted.as_slice()); } + #[test] + fn test_aes_ctr_in_place() { + let key = [0x42u8; 32]; + let iv = 999u128; + + let original = b"Test data for in-place encryption"; + let mut data = original.to_vec(); + + let mut cipher = 
AesCtr::new(&key, iv); + cipher.apply(&mut data); + + // Encrypted should be different + assert_ne!(&data[..], original); + + // Decrypt with fresh cipher + let mut cipher = AesCtr::new(&key, iv); + cipher.apply(&mut data); + + assert_eq!(&data[..], original); + } + + // ============= AES-CBC Tests ============= + #[test] fn test_aes_cbc_roundtrip() { let key = [0u8; 32]; let iv = [0u8; 16]; - // Must be aligned to 16 bytes - let original = [0u8; 32]; + let original = [0u8; 32]; // 2 blocks let cipher = AesCbc::new(key, iv); let encrypted = cipher.encrypt(&original).unwrap(); @@ -242,47 +375,59 @@ mod tests { #[test] fn test_aes_cbc_chaining_works() { + // This is the key test - verify CBC chaining is correct let key = [0x42u8; 32]; let iv = [0x00u8; 16]; - - let plaintext = [0xAA_u8; 32]; + + // Two IDENTICAL plaintext blocks + let plaintext = [0xAAu8; 32]; let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // CBC Corrections + // With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext let block1 = &ciphertext[0..16]; let block2 = &ciphertext[16..32]; - assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext"); + assert_ne!( + block1, block2, + "CBC chaining broken: identical plaintext blocks produced identical ciphertext. \ + This indicates ECB mode, not CBC!" + ); } #[test] - fn test_aes_cbc_known_vector() { + fn test_aes_cbc_known_vector() { + // Test with known NIST test vector + // AES-256-CBC with zero key and zero IV let key = [0u8; 32]; let iv = [0u8; 16]; - - // 3 Datablocks - let plaintext = [ - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, - 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, - // Block 2 - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, - 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, - // Block 3 - different - 0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88, - 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00, - ]; + let plaintext = [0u8; 16]; let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // Decrypt + Verify + // Decrypt and verify roundtrip let decrypted = cipher.decrypt(&ciphertext).unwrap(); assert_eq!(plaintext.as_slice(), decrypted.as_slice()); - // Verify Ciphertexts Block 1 != Block 2 - assert_ne!(&ciphertext[0..16], &ciphertext[16..32]); + // Ciphertext should not be all zeros + assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); + } + + #[test] + fn test_aes_cbc_multi_block() { + let key = [0x12u8; 32]; + let iv = [0x34u8; 16]; + + // 5 blocks = 80 bytes + let plaintext: Vec = (0..80).collect(); + + let cipher = AesCbc::new(key, iv); + let ciphertext = cipher.encrypt(&plaintext).unwrap(); + let decrypted = cipher.decrypt(&ciphertext).unwrap(); + + assert_eq!(plaintext, decrypted); } #[test] @@ -291,7 +436,7 @@ mod tests { let iv = [0x34u8; 16]; let original = [0x56u8; 48]; // 3 blocks - let mut buffer = original.clone(); + let mut buffer = original; let cipher = AesCbc::new(key, iv); @@ -317,35 +462,85 @@ mod tests { fn test_aes_cbc_unaligned_error() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - // 15 bytes + // 15 bytes - not aligned to block size let result = cipher.encrypt(&[0u8; 15]); assert!(result.is_err()); - // 17 bytes + // 17 bytes - not aligned let result = cipher.encrypt(&[0u8; 17]); assert!(result.is_err()); } #[test] fn test_aes_cbc_avalanche_effect() { - // Cipherplane - + // Changing one bit in plaintext should change entire ciphertext block + // and all subsequent blocks (due to 
chaining) let key = [0xAB; 32]; let iv = [0xCD; 16]; let mut plaintext1 = [0u8; 32]; let mut plaintext2 = [0u8; 32]; - plaintext2[0] = 0x01; // Один бит отличается + plaintext2[0] = 0x01; // Single bit difference in first block let cipher = AesCbc::new(key, iv); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); - // First Blocks Diff + // First blocks should be different assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); - // Second Blocks Diff + // Second blocks should ALSO be different (chaining effect) assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); } + + #[test] + fn test_aes_cbc_iv_matters() { + // Same plaintext with different IVs should produce different ciphertext + let key = [0x55; 32]; + let plaintext = [0x77u8; 16]; + + let cipher1 = AesCbc::new(key, [0u8; 16]); + let cipher2 = AesCbc::new(key, [1u8; 16]); + + let ciphertext1 = cipher1.encrypt(&plaintext).unwrap(); + let ciphertext2 = cipher2.encrypt(&plaintext).unwrap(); + + assert_ne!(ciphertext1, ciphertext2); + } + + #[test] + fn test_aes_cbc_deterministic() { + // Same key, IV, plaintext should always produce same ciphertext + let key = [0x99; 32]; + let iv = [0x88; 16]; + let plaintext = [0x77u8; 32]; + + let cipher = AesCbc::new(key, iv); + + let ciphertext1 = cipher.encrypt(&plaintext).unwrap(); + let ciphertext2 = cipher.encrypt(&plaintext).unwrap(); + + assert_eq!(ciphertext1, ciphertext2); + } + + // ============= Error Handling Tests ============= + + #[test] + fn test_invalid_key_length() { + let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]); + assert!(result.is_err()); + + let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_iv_length() { + let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]); + assert!(result.is_err()); + + let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]); + assert!(result.is_err()); + } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index c14e148..bd49757 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,177 @@ //! 
Error Types +use std::fmt; use std::net::SocketAddr; use thiserror::Error; +// ============= Stream Errors ============= + +/// Errors specific to stream I/O operations +#[derive(Debug)] +pub enum StreamError { + /// Partial read: got fewer bytes than expected + PartialRead { + expected: usize, + got: usize, + }, + /// Partial write: wrote fewer bytes than expected + PartialWrite { + expected: usize, + written: usize, + }, + /// Stream is in poisoned state and cannot be used + Poisoned { + reason: String, + }, + /// Buffer overflow: attempted to buffer more than allowed + BufferOverflow { + limit: usize, + attempted: usize, + }, + /// Invalid frame format + InvalidFrame { + details: String, + }, + /// Unexpected end of stream + UnexpectedEof, + /// Underlying I/O error + Io(std::io::Error), +} + +impl fmt::Display for StreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::PartialRead { expected, got } => { + write!(f, "partial read: expected {} bytes, got {}", expected, got) + } + Self::PartialWrite { expected, written } => { + write!(f, "partial write: expected {} bytes, wrote {}", expected, written) + } + Self::Poisoned { reason } => { + write!(f, "stream poisoned: {}", reason) + } + Self::BufferOverflow { limit, attempted } => { + write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted) + } + Self::InvalidFrame { details } => { + write!(f, "invalid frame: {}", details) + } + Self::UnexpectedEof => { + write!(f, "unexpected end of stream") + } + Self::Io(e) => { + write!(f, "I/O error: {}", e) + } + } + } +} + +impl std::error::Error for StreamError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + _ => None, + } + } +} + +impl From for StreamError { + fn from(err: std::io::Error) -> Self { + Self::Io(err) + } +} + +impl From for std::io::Error { + fn from(err: StreamError) -> Self { + match err { + StreamError::Io(e) => e, + StreamError::UnexpectedEof => { + std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err) + } + StreamError::Poisoned { .. } => { + std::io::Error::new(std::io::ErrorKind::Other, err) + } + StreamError::BufferOverflow { .. } => { + std::io::Error::new(std::io::ErrorKind::OutOfMemory, err) + } + StreamError::InvalidFrame { .. } => { + std::io::Error::new(std::io::ErrorKind::InvalidData, err) + } + StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => { + std::io::Error::new(std::io::ErrorKind::Other, err) + } + } + } +} + +// ============= Recoverable Trait ============= + +/// Trait for errors that may be recoverable +pub trait Recoverable { + /// Check if error is recoverable (can retry operation) + fn is_recoverable(&self) -> bool; + + /// Check if connection can continue after this error + fn can_continue(&self) -> bool; +} + +impl Recoverable for StreamError { + fn is_recoverable(&self) -> bool { + match self { + // Partial operations can be retried + Self::PartialRead { .. } | Self::PartialWrite { .. } => true, + // I/O errors depend on kind + Self::Io(e) => matches!( + e.kind(), + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut + ), + // These are not recoverable + Self::Poisoned { .. } + | Self::BufferOverflow { .. } + | Self::InvalidFrame { .. } + | Self::UnexpectedEof => false, + } + } + + fn can_continue(&self) -> bool { + match self { + // Poisoned stream cannot be used + Self::Poisoned { .. 
} => false, + // EOF means stream is done + Self::UnexpectedEof => false, + // Buffer overflow is fatal + Self::BufferOverflow { .. } => false, + // Others might allow continuation + _ => true, + } + } +} + +impl Recoverable for std::io::Error { + fn is_recoverable(&self) -> bool { + matches!( + self.kind(), + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut + ) + } + + fn can_continue(&self) -> bool { + !matches!( + self.kind(), + std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::NotConnected + ) + } +} + +// ============= Main Proxy Errors ============= + #[derive(Error, Debug)] pub enum ProxyError { // ============= Crypto Errors ============= @@ -13,6 +182,11 @@ pub enum ProxyError { #[error("Invalid key length: expected {expected}, got {got}")] InvalidKeyLength { expected: usize, got: usize }, + // ============= Stream Errors ============= + + #[error("Stream error: {0}")] + Stream(#[from] StreamError), + // ============= Protocol Errors ============= #[error("Invalid handshake: {0}")] @@ -39,6 +213,12 @@ pub enum ProxyError { #[error("Sequence number mismatch: expected={expected}, got={got}")] SeqNoMismatch { expected: i32, got: i32 }, + #[error("TLS handshake failed: {reason}")] + TlsHandshakeFailed { reason: String }, + + #[error("Telegram handshake timeout")] + TgHandshakeTimeout, + // ============= Network Errors ============= #[error("Connection timeout to {addr}")] @@ -77,15 +257,41 @@ pub enum ProxyError { #[error("Unknown user")] UnknownUser, + #[error("Rate limited")] + RateLimited, + // ============= General Errors ============= #[error("Internal error: {0}")] Internal(String), } +impl Recoverable for ProxyError { + fn is_recoverable(&self) -> bool { + match self { + Self::Stream(e) => e.is_recoverable(), + Self::Io(e) => e.is_recoverable(), + Self::ConnectionTimeout { .. 
} => true,
+            Self::RateLimited => true,
+            _ => false,
+        }
+    }
+
+    fn can_continue(&self) -> bool {
+        match self {
+            Self::Stream(e) => e.can_continue(),
+            Self::Io(e) => e.can_continue(),
+            _ => false,
+        }
+    }
+}
+
 /// Convenient Result type alias
 pub type Result<T> = std::result::Result<T, ProxyError>;
 
+/// Result type for stream operations
+pub type StreamResult<T> = std::result::Result<T, StreamError>;
+
 /// Result with optional bad client handling
 #[derive(Debug)]
 pub enum HandshakeResult<T> {
@@ -125,6 +331,14 @@ impl<T> HandshakeResult<T> {
             HandshakeResult::Error(e) => HandshakeResult::Error(e),
         }
     }
+
+    /// Convert success to Option
+    pub fn ok(self) -> Option<T> {
+        match self {
+            HandshakeResult::Success(v) => Some(v),
+            _ => None,
+        }
+    }
 }
 
 impl<T> From<ProxyError> for HandshakeResult<T> {
@@ -139,10 +353,48 @@
     }
 }
 
+impl<T> From<StreamError> for HandshakeResult<T> {
+    fn from(err: StreamError) -> Self {
+        HandshakeResult::Error(ProxyError::Stream(err))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
+    #[test]
+    fn test_stream_error_display() {
+        let err = StreamError::PartialRead { expected: 100, got: 50 };
+        assert!(err.to_string().contains("100"));
+        assert!(err.to_string().contains("50"));
+
+        let err = StreamError::Poisoned { reason: "test".into() };
+        assert!(err.to_string().contains("test"));
+    }
+
+    #[test]
+    fn test_stream_error_recoverable() {
+        assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
+        assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
+        assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
+        assert!(!StreamError::UnexpectedEof.is_recoverable());
+    }
+
+    #[test]
+    fn test_stream_error_can_continue() {
+        assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
+        assert!(!StreamError::UnexpectedEof.can_continue());
+        assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
+    }
+
+    #[test]
+    fn test_stream_error_to_io_error() {
+        let stream_err = StreamError::UnexpectedEof;
+        let io_err: std::io::Error = stream_err.into();
+        assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
+    }
+
     #[test]
     fn test_handshake_result() {
         let success: HandshakeResult<i32> = HandshakeResult::Success(42);
@@ -165,6 +417,15 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_proxy_error_recoverable() {
+        let err = ProxyError::RateLimited;
+        assert!(err.is_recoverable());
+
+        let err = ProxyError::InvalidHandshake("bad".into());
+        assert!(!err.is_recoverable());
+    }
+
     #[test]
     fn test_error_display() {
         let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs
new file mode 100644
index 0000000..55be736
--- /dev/null
+++ b/src/stream/buffer_pool.rs
@@ -0,0 +1,450 @@
+//! Reusable buffer pool to avoid allocations in hot paths
+//!
+//! This module provides a thread-safe pool of BytesMut buffers
+//! that can be reused across connections to reduce allocation pressure.
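+//!
+//! # Example
+//!
+//! A minimal usage sketch; the pool dimensions and payload below are
+//! illustrative, not values used elsewhere in the proxy:
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//!
+//! let pool = Arc::new(BufferPool::with_config(64 * 1024, 256));
+//! pool.preallocate(32);
+//!
+//! let mut buf = pool.get();            // reuses a pooled BytesMut if available
+//! buf.extend_from_slice(b"payload");   // PooledBuffer derefs to BytesMut
+//! drop(buf);                           // automatically returned to the pool
+//! ```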
+
+use bytes::BytesMut;
+use crossbeam_queue::ArrayQueue;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+
+// ============= Configuration =============
+
+/// Default buffer size (64KB - good for MTProto)
+pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
+
+/// Default maximum number of pooled buffers
+pub const DEFAULT_MAX_BUFFERS: usize = 1024;
+
+// ============= Buffer Pool =============
+
+/// Thread-safe pool of reusable buffers
+pub struct BufferPool {
+    /// Queue of available buffers
+    buffers: ArrayQueue<BytesMut>,
+    /// Size of each buffer
+    buffer_size: usize,
+    /// Maximum number of buffers to pool
+    max_buffers: usize,
+    /// Total allocated buffers (including in-use)
+    allocated: AtomicUsize,
+    /// Number of times we had to create a new buffer
+    misses: AtomicUsize,
+    /// Number of successful reuses
+    hits: AtomicUsize,
+}
+
+impl BufferPool {
+    /// Create a new buffer pool with default settings
+    pub fn new() -> Self {
+        Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
+    }
+
+    /// Create a buffer pool with custom configuration
+    pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
+        Self {
+            buffers: ArrayQueue::new(max_buffers),
+            buffer_size,
+            max_buffers,
+            allocated: AtomicUsize::new(0),
+            misses: AtomicUsize::new(0),
+            hits: AtomicUsize::new(0),
+        }
+    }
+
+    /// Get a buffer from the pool, or create a new one if empty
+    pub fn get(self: &Arc<Self>) -> PooledBuffer {
+        match self.buffers.pop() {
+            Some(mut buffer) => {
+                self.hits.fetch_add(1, Ordering::Relaxed);
+                buffer.clear();
+                PooledBuffer {
+                    buffer: Some(buffer),
+                    pool: Arc::clone(self),
+                }
+            }
+            None => {
+                self.misses.fetch_add(1, Ordering::Relaxed);
+                self.allocated.fetch_add(1, Ordering::Relaxed);
+                PooledBuffer {
+                    buffer: Some(BytesMut::with_capacity(self.buffer_size)),
+                    pool: Arc::clone(self),
+                }
+            }
+        }
+    }
+
+    /// Try to get a buffer, returns None if pool is empty
+    pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
+        self.buffers.pop().map(|mut buffer| {
+            self.hits.fetch_add(1, Ordering::Relaxed);
+            buffer.clear();
+            PooledBuffer {
+                buffer: Some(buffer),
+                pool: Arc::clone(self),
+            }
+        })
+    }
+
+    /// Return a buffer to the pool
+    fn return_buffer(&self, mut buffer: BytesMut) {
+        // Clear the buffer but keep capacity
+        buffer.clear();
+
+        // Only keep buffers that still have at least the configured capacity
+        if buffer.capacity() >= self.buffer_size {
+            // Push back into the pool; if the queue is full the buffer is dropped
+            let _ = self.buffers.push(buffer);
+        }
+        // Note: `allocated` is not decremented here - it tracks the total number
+        // of buffers ever created, so dropped buffers simply fall out of use.
+    }
+
+    /// Get pool statistics
+    pub fn stats(&self) -> PoolStats {
+        PoolStats {
+            pooled: self.buffers.len(),
+            allocated: self.allocated.load(Ordering::Relaxed),
+            max_buffers: self.max_buffers,
+            buffer_size: self.buffer_size,
+            hits: self.hits.load(Ordering::Relaxed),
+            misses: self.misses.load(Ordering::Relaxed),
+        }
+    }
+
+    /// Get buffer size
+    pub fn buffer_size(&self) -> usize {
+        self.buffer_size
+    }
+
+    /// Preallocate buffers to fill the pool
+    pub fn preallocate(&self, count: usize) {
+        let to_alloc = count.min(self.max_buffers);
+        for _ in 0..to_alloc {
+            if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
+                break;
+            }
+            self.allocated.fetch_add(1, Ordering::Relaxed);
+        }
+    }
+}
+
+impl Default for BufferPool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// ============= Pool Statistics
============= + +/// Statistics about buffer pool usage +#[derive(Debug, Clone)] +pub struct PoolStats { + /// Current number of buffers in pool + pub pooled: usize, + /// Total buffers allocated (in-use + pooled) + pub allocated: usize, + /// Maximum buffers allowed + pub max_buffers: usize, + /// Size of each buffer + pub buffer_size: usize, + /// Number of cache hits (reused buffer) + pub hits: usize, + /// Number of cache misses (new allocation) + pub misses: usize, +} + +impl PoolStats { + /// Get hit rate as percentage + pub fn hit_rate(&self) -> f64 { + let total = self.hits + self.misses; + if total == 0 { + 0.0 + } else { + (self.hits as f64 / total as f64) * 100.0 + } + } +} + +// ============= Pooled Buffer ============= + +/// A buffer that automatically returns to the pool when dropped +pub struct PooledBuffer { + buffer: Option, + pool: Arc, +} + +impl PooledBuffer { + /// Take the inner buffer, preventing return to pool + pub fn take(mut self) -> BytesMut { + self.buffer.take().unwrap() + } + + /// Get the capacity of the buffer + pub fn capacity(&self) -> usize { + self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0) + } + + /// Check if buffer is empty + pub fn is_empty(&self) -> bool { + self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true) + } + + /// Get the length of data in buffer + pub fn len(&self) -> usize { + self.buffer.as_ref().map(|b| b.len()).unwrap_or(0) + } + + /// Clear the buffer + pub fn clear(&mut self) { + if let Some(ref mut b) = self.buffer { + b.clear(); + } + } +} + +impl Deref for PooledBuffer { + type Target = BytesMut; + + fn deref(&self) -> &Self::Target { + self.buffer.as_ref().expect("buffer taken") + } +} + +impl DerefMut for PooledBuffer { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buffer.as_mut().expect("buffer taken") + } +} + +impl Drop for PooledBuffer { + fn drop(&mut self) { + if let Some(buffer) = self.buffer.take() { + self.pool.return_buffer(buffer); + } + } +} + +impl AsRef<[u8]> for PooledBuffer { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[]) + } +} + +impl AsMut<[u8]> for PooledBuffer { + fn as_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut []) + } +} + +// ============= Scoped Buffer ============= + +/// A buffer that can be used for a scoped operation +/// Useful for ensuring buffer is returned even on early return +pub struct ScopedBuffer<'a> { + buffer: &'a mut PooledBuffer, +} + +impl<'a> ScopedBuffer<'a> { + /// Create a new scoped buffer + pub fn new(buffer: &'a mut PooledBuffer) -> Self { + buffer.clear(); + Self { buffer } + } +} + +impl<'a> Deref for ScopedBuffer<'a> { + type Target = BytesMut; + + fn deref(&self) -> &Self::Target { + self.buffer.deref() + } +} + +impl<'a> DerefMut for ScopedBuffer<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buffer.deref_mut() + } +} + +impl<'a> Drop for ScopedBuffer<'a> { + fn drop(&mut self) { + self.buffer.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pool_basic() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + // Get a buffer + let mut buf1 = pool.get(); + buf1.extend_from_slice(b"hello"); + assert_eq!(&buf1[..], b"hello"); + + // Drop returns to pool + drop(buf1); + + let stats = pool.stats(); + assert_eq!(stats.pooled, 1); + assert_eq!(stats.hits, 0); + assert_eq!(stats.misses, 1); + + // Get again - should reuse + let buf2 = pool.get(); + assert!(buf2.is_empty()); // Buffer was cleared + + let stats = 
pool.stats(); + assert_eq!(stats.pooled, 0); + assert_eq!(stats.hits, 1); + } + + #[test] + fn test_pool_multiple_buffers() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + // Get multiple buffers + let buf1 = pool.get(); + let buf2 = pool.get(); + let buf3 = pool.get(); + + let stats = pool.stats(); + assert_eq!(stats.allocated, 3); + assert_eq!(stats.pooled, 0); + + // Return all + drop(buf1); + drop(buf2); + drop(buf3); + + let stats = pool.stats(); + assert_eq!(stats.pooled, 3); + } + + #[test] + fn test_pool_overflow() { + let pool = Arc::new(BufferPool::with_config(1024, 2)); + + // Get 3 buffers (more than max) + let buf1 = pool.get(); + let buf2 = pool.get(); + let buf3 = pool.get(); + + // Return all - only 2 should be pooled + drop(buf1); + drop(buf2); + drop(buf3); + + let stats = pool.stats(); + assert_eq!(stats.pooled, 2); + } + + #[test] + fn test_pool_take() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + let mut buf = pool.get(); + buf.extend_from_slice(b"data"); + + // Take ownership, buffer should not return to pool + let taken = buf.take(); + assert_eq!(&taken[..], b"data"); + + let stats = pool.stats(); + assert_eq!(stats.pooled, 0); + } + + #[test] + fn test_pool_preallocate() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + pool.preallocate(5); + + let stats = pool.stats(); + assert_eq!(stats.pooled, 5); + assert_eq!(stats.allocated, 5); + } + + #[test] + fn test_pool_try_get() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + // Pool is empty, try_get returns None + assert!(pool.try_get().is_none()); + + // Add a buffer to pool + pool.preallocate(1); + + // Now try_get should succeed + assert!(pool.try_get().is_some()); + assert!(pool.try_get().is_none()); + } + + #[test] + fn test_hit_rate() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + // First get is a miss + let buf1 = pool.get(); + drop(buf1); + + // Second get is a hit + let buf2 = pool.get(); + drop(buf2); + + // Third get is a hit + let _buf3 = pool.get(); + + let stats = pool.stats(); + assert_eq!(stats.hits, 2); + assert_eq!(stats.misses, 1); + assert!((stats.hit_rate() - 66.67).abs() < 1.0); + } + + #[test] + fn test_scoped_buffer() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + let mut buf = pool.get(); + + { + let mut scoped = ScopedBuffer::new(&mut buf); + scoped.extend_from_slice(b"scoped data"); + assert_eq!(&scoped[..], b"scoped data"); + } + + // After scoped is dropped, buffer is cleared + assert!(buf.is_empty()); + } + + #[test] + fn test_concurrent_access() { + use std::thread; + + let pool = Arc::new(BufferPool::with_config(1024, 100)); + let mut handles = vec![]; + + for _ in 0..10 { + let pool_clone = Arc::clone(&pool); + handles.push(thread::spawn(move || { + for _ in 0..100 { + let mut buf = pool_clone.get(); + buf.extend_from_slice(b"test"); + // buf auto-returned on drop + } + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + let stats = pool.stats(); + // All buffers should be returned + assert!(stats.pooled > 0); + } +} \ No newline at end of file diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 1d98a5e..2f5e545 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,10 +1,22 @@ //! 
Stream wrappers for MTProto protocol layers
 
+pub mod state;
+pub mod buffer_pool;
 pub mod traits;
 pub mod crypto_stream;
 pub mod tls_stream;
 pub mod frame_stream;
 
+// Re-export state machine types
+pub use state::{
+    StreamState, Transition, PollResult,
+    ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
+};
+
+// Re-export buffer pool
+pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
+
+// Re-export stream implementations
 pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
 pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
 pub use frame_stream::*;
\ No newline at end of file
diff --git a/src/stream/state.rs b/src/stream/state.rs
new file mode 100644
index 0000000..c4f52e6
--- /dev/null
+++ b/src/stream/state.rs
@@ -0,0 +1,571 @@
+//! State machine foundation types for async streams
+//!
+//! This module provides core types and traits for implementing
+//! stateful async streams with proper partial read/write handling.
+
+use bytes::{Bytes, BytesMut};
+use std::io;
+
+// ============= Core Traits =============
+
+/// Trait for stream states
+pub trait StreamState: Sized {
+    /// Check if this is a terminal state (no more transitions possible)
+    fn is_terminal(&self) -> bool;
+
+    /// Check if stream is in poisoned/error state
+    fn is_poisoned(&self) -> bool;
+
+    /// Get human-readable state name for debugging
+    fn state_name(&self) -> &'static str;
+}
+
+// ============= Transition Types =============
+
+/// Result of a state transition
+#[derive(Debug)]
+pub enum Transition<S, O> {
+    /// Stay in the same state, no output
+    Same,
+    /// Transition to a new state, no output
+    Next(S),
+    /// Complete with output, typically transitions to Idle
+    Complete(O),
+    /// Yield output and transition to new state
+    Yield(O, S),
+    /// Error occurred, transition to error state
+    Error(io::Error),
+}
+
+impl<S, O> Transition<S, O> {
+    /// Check if transition produces output
+    pub fn has_output(&self) -> bool {
+        matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
+    }
+
+    /// Map the output value
+    pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
+        match self {
+            Transition::Same => Transition::Same,
+            Transition::Next(s) => Transition::Next(s),
+            Transition::Complete(o) => Transition::Complete(f(o)),
+            Transition::Yield(o, s) => Transition::Yield(f(o), s),
+            Transition::Error(e) => Transition::Error(e),
+        }
+    }
+
+    /// Map the state value
+    pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
+        match self {
+            Transition::Same => Transition::Same,
+            Transition::Next(s) => Transition::Next(f(s)),
+            Transition::Complete(o) => Transition::Complete(o),
+            Transition::Yield(o, s) => Transition::Yield(o, f(s)),
+            Transition::Error(e) => Transition::Error(e),
+        }
+    }
+}
+
+// ============= Poll Result Types =============
+
+/// Result of polling for more data
+#[derive(Debug)]
+pub enum PollResult<T> {
+    /// Data is ready
+    Ready(T),
+    /// Operation would block, need to poll again
+    Pending,
+    /// Need more input data (minimum bytes required)
+    NeedInput(usize),
+    /// End of stream reached
+    Eof,
+    /// Error occurred
+    Error(io::Error),
+}
+
+impl<T> PollResult<T> {
+    /// Check if result is ready
+    pub fn is_ready(&self) -> bool {
+        matches!(self, PollResult::Ready(_))
+    }
+
+    /// Check if result indicates EOF
+    pub fn is_eof(&self) -> bool {
+        matches!(self, PollResult::Eof)
+    }
+
+    /// Convert to Option, discarding non-ready states
+    pub fn ok(self) -> Option<T> {
+        match self {
+            PollResult::Ready(t) => Some(t),
+            _ => None,
+        }
+    }
+
+    /// Map the value
+    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
match self { + PollResult::Ready(t) => PollResult::Ready(f(t)), + PollResult::Pending => PollResult::Pending, + PollResult::NeedInput(n) => PollResult::NeedInput(n), + PollResult::Eof => PollResult::Eof, + PollResult::Error(e) => PollResult::Error(e), + } + } +} + +impl From> for PollResult { + fn from(result: io::Result) -> Self { + match result { + Ok(t) => PollResult::Ready(t), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof, + Err(e) => PollResult::Error(e), + } + } +} + +// ============= Buffer State ============= + +/// State for buffered reading operations +#[derive(Debug)] +pub struct ReadBuffer { + /// The buffer holding data + buffer: BytesMut, + /// Target number of bytes to read (if known) + target: Option, +} + +impl ReadBuffer { + /// Create new empty read buffer + pub fn new() -> Self { + Self { + buffer: BytesMut::with_capacity(8192), + target: None, + } + } + + /// Create with specific capacity + pub fn with_capacity(capacity: usize) -> Self { + Self { + buffer: BytesMut::with_capacity(capacity), + target: None, + } + } + + /// Create with target size + pub fn with_target(target: usize) -> Self { + Self { + buffer: BytesMut::with_capacity(target), + target: Some(target), + } + } + + /// Get current buffer length + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Check if buffer is empty + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Check if target is reached + pub fn is_complete(&self) -> bool { + match self.target { + Some(t) => self.buffer.len() >= t, + None => false, + } + } + + /// Get remaining bytes needed + pub fn remaining(&self) -> usize { + match self.target { + Some(t) => t.saturating_sub(self.buffer.len()), + None => 0, + } + } + + /// Append data to buffer + pub fn extend(&mut self, data: &[u8]) { + self.buffer.extend_from_slice(data); + } + + /// Take all data from buffer + pub fn take(&mut self) -> Bytes { + self.target = None; + self.buffer.split().freeze() + } + + /// Take exactly n bytes + pub fn take_exact(&mut self, n: usize) -> Option { + if self.buffer.len() >= n { + Some(self.buffer.split_to(n).freeze()) + } else { + None + } + } + + /// Get a slice of the buffer + pub fn as_slice(&self) -> &[u8] { + &self.buffer + } + + /// Get mutable access to underlying BytesMut + pub fn as_bytes_mut(&mut self) -> &mut BytesMut { + &mut self.buffer + } + + /// Clear the buffer + pub fn clear(&mut self) { + self.buffer.clear(); + self.target = None; + } + + /// Set new target + pub fn set_target(&mut self, target: usize) { + self.target = Some(target); + } +} + +impl Default for ReadBuffer { + fn default() -> Self { + Self::new() + } +} + +/// State for buffered writing operations +#[derive(Debug)] +pub struct WriteBuffer { + /// The buffer holding data to write + buffer: BytesMut, + /// Position of next byte to write + position: usize, + /// Maximum buffer size + max_size: usize, +} + +impl WriteBuffer { + /// Create new write buffer with default max size (256KB) + pub fn new() -> Self { + Self::with_max_size(256 * 1024) + } + + /// Create with specific max size + pub fn with_max_size(max_size: usize) -> Self { + Self { + buffer: BytesMut::with_capacity(8192), + position: 0, + max_size, + } + } + + /// Get pending bytes count + pub fn len(&self) -> usize { + self.buffer.len() - self.position + } + + /// Check if buffer is empty (all written) + pub fn is_empty(&self) -> bool { + self.position >= self.buffer.len() + } + + /// Check 
if buffer is full + pub fn is_full(&self) -> bool { + self.buffer.len() >= self.max_size + } + + /// Get remaining capacity + pub fn remaining_capacity(&self) -> usize { + self.max_size.saturating_sub(self.buffer.len()) + } + + /// Append data to buffer + pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> { + if self.buffer.len() + data.len() > self.max_size { + return Err(()); + } + self.buffer.extend_from_slice(data); + Ok(()) + } + + /// Get slice of data to write + pub fn pending(&self) -> &[u8] { + &self.buffer[self.position..] + } + + /// Advance position by n bytes (after successful write) + pub fn advance(&mut self, n: usize) { + self.position += n; + + // If all data written, reset buffer + if self.position >= self.buffer.len() { + self.buffer.clear(); + self.position = 0; + } + } + + /// Clear the buffer + pub fn clear(&mut self) { + self.buffer.clear(); + self.position = 0; + } +} + +impl Default for WriteBuffer { + fn default() -> Self { + Self::new() + } +} + +// ============= Fixed-Size Buffer States ============= + +/// State for reading a fixed-size header +#[derive(Debug, Clone)] +pub struct HeaderBuffer { + /// The buffer + data: [u8; N], + /// Bytes filled so far + filled: usize, +} + +impl HeaderBuffer { + /// Create new empty header buffer + pub fn new() -> Self { + Self { + data: [0u8; N], + filled: 0, + } + } + + /// Get slice for reading into + pub fn unfilled_mut(&mut self) -> &mut [u8] { + &mut self.data[self.filled..] + } + + /// Advance filled count + pub fn advance(&mut self, n: usize) { + self.filled = (self.filled + n).min(N); + } + + /// Check if completely filled + pub fn is_complete(&self) -> bool { + self.filled >= N + } + + /// Get remaining bytes needed + pub fn remaining(&self) -> usize { + N - self.filled + } + + /// Get filled bytes as slice + pub fn as_slice(&self) -> &[u8] { + &self.data[..self.filled] + } + + /// Get complete buffer (panics if not complete) + pub fn as_array(&self) -> &[u8; N] { + assert!(self.is_complete()); + &self.data + } + + /// Take the buffer, resetting state + pub fn take(&mut self) -> [u8; N] { + let data = self.data; + self.data = [0u8; N]; + self.filled = 0; + data + } + + /// Reset to empty state + pub fn reset(&mut self) { + self.filled = 0; + } +} + +impl Default for HeaderBuffer { + fn default() -> Self { + Self::new() + } +} + +// ============= Yield Buffer ============= + +/// Buffer for yielding data to caller in chunks +#[derive(Debug)] +pub struct YieldBuffer { + data: Bytes, + position: usize, +} + +impl YieldBuffer { + /// Create new yield buffer + pub fn new(data: Bytes) -> Self { + Self { data, position: 0 } + } + + /// Check if all data has been yielded + pub fn is_empty(&self) -> bool { + self.position >= self.data.len() + } + + /// Get remaining bytes + pub fn remaining(&self) -> usize { + self.data.len() - self.position + } + + /// Copy data to output slice, return bytes copied + pub fn copy_to(&mut self, dst: &mut [u8]) -> usize { + let available = &self.data[self.position..]; + let to_copy = available.len().min(dst.len()); + dst[..to_copy].copy_from_slice(&available[..to_copy]); + self.position += to_copy; + to_copy + } + + /// Get remaining data as slice + pub fn as_slice(&self) -> &[u8] { + &self.data[self.position..] + } +} + +// ============= Macros ============= + +/// Macro to simplify state transitions in poll methods +#[macro_export] +macro_rules! 
transition { + (same) => { + $crate::stream::state::Transition::Same + }; + (next $state:expr) => { + $crate::stream::state::Transition::Next($state) + }; + (complete $output:expr) => { + $crate::stream::state::Transition::Complete($output) + }; + (yield $output:expr, $state:expr) => { + $crate::stream::state::Transition::Yield($output, $state) + }; + (error $err:expr) => { + $crate::stream::state::Transition::Error($err) + }; +} + +/// Macro to match poll ready or return pending +#[macro_export] +macro_rules! ready_or_pending { + ($poll:expr) => { + match $poll { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_buffer_basic() { + let mut buf = ReadBuffer::with_target(10); + assert_eq!(buf.remaining(), 10); + assert!(!buf.is_complete()); + + buf.extend(b"hello"); + assert_eq!(buf.len(), 5); + assert_eq!(buf.remaining(), 5); + assert!(!buf.is_complete()); + + buf.extend(b"world"); + assert_eq!(buf.len(), 10); + assert!(buf.is_complete()); + } + + #[test] + fn test_read_buffer_take() { + let mut buf = ReadBuffer::new(); + buf.extend(b"test data"); + + let data = buf.take(); + assert_eq!(&data[..], b"test data"); + assert!(buf.is_empty()); + } + + #[test] + fn test_write_buffer_basic() { + let mut buf = WriteBuffer::with_max_size(100); + assert!(buf.is_empty()); + + buf.extend(b"hello").unwrap(); + assert_eq!(buf.len(), 5); + assert!(!buf.is_empty()); + + buf.advance(3); + assert_eq!(buf.len(), 2); + assert_eq!(buf.pending(), b"lo"); + } + + #[test] + fn test_write_buffer_overflow() { + let mut buf = WriteBuffer::with_max_size(10); + assert!(buf.extend(b"short").is_ok()); + assert!(buf.extend(b"toolong").is_err()); + } + + #[test] + fn test_header_buffer() { + let mut buf = HeaderBuffer::<5>::new(); + assert!(!buf.is_complete()); + assert_eq!(buf.remaining(), 5); + + buf.unfilled_mut()[..3].copy_from_slice(b"hel"); + buf.advance(3); + assert_eq!(buf.remaining(), 2); + + buf.unfilled_mut()[..2].copy_from_slice(b"lo"); + buf.advance(2); + assert!(buf.is_complete()); + assert_eq!(buf.as_array(), b"hello"); + } + + #[test] + fn test_yield_buffer() { + let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world")); + + let mut dst = [0u8; 5]; + assert_eq!(buf.copy_to(&mut dst), 5); + assert_eq!(&dst, b"hello"); + + assert_eq!(buf.remaining(), 6); + + let mut dst = [0u8; 10]; + assert_eq!(buf.copy_to(&mut dst), 6); + assert_eq!(&dst[..6], b" world"); + + assert!(buf.is_empty()); + } + + #[test] + fn test_transition_map() { + let t: Transition = Transition::Complete("hello".to_string()); + let t = t.map_output(|s| s.len()); + + match t { + Transition::Complete(5) => {} + _ => panic!("Expected Complete(5)"), + } + } + + #[test] + fn test_poll_result() { + let r: PollResult = PollResult::Ready(42); + assert!(r.is_ready()); + assert_eq!(r.ok(), Some(42)); + + let r: PollResult = PollResult::Eof; + assert!(r.is_eof()); + assert_eq!(r.ok(), None); + } +} \ No newline at end of file
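
A sketch of how the state-machine primitives added in `src/stream/state.rs` might be driven by a concrete reader. `FrameState`, `on_bytes`, and the 4-byte length-prefix framing are illustrative only and not part of this change; for brevity the sketch ignores input that spans the header/body boundary:

```rust
use bytes::Bytes;
use crate::stream::{HeaderBuffer, ReadBuffer, StreamState, Transition};

/// Illustrative states for a length-prefixed frame reader.
enum FrameState {
    /// Reading the 4-byte big-endian length prefix
    Header(HeaderBuffer<4>),
    /// Reading the payload of the announced length
    Body(ReadBuffer),
    /// A full frame has been yielded
    Done,
}

impl StreamState for FrameState {
    fn is_terminal(&self) -> bool {
        matches!(self, FrameState::Done)
    }
    fn is_poisoned(&self) -> bool {
        false
    }
    fn state_name(&self) -> &'static str {
        match self {
            FrameState::Header(_) => "Header",
            FrameState::Body(_) => "Body",
            FrameState::Done => "Done",
        }
    }
}

/// Feed freshly-read bytes into the state machine and report what happened.
fn on_bytes(state: &mut FrameState, input: &[u8]) -> Transition<FrameState, Bytes> {
    match state {
        FrameState::Header(hdr) => {
            // Copy as much of the length prefix as we have, then decide.
            let n = input.len().min(hdr.remaining());
            hdr.unfilled_mut()[..n].copy_from_slice(&input[..n]);
            hdr.advance(n);
            if hdr.is_complete() {
                let len = u32::from_be_bytes(*hdr.as_array()) as usize;
                Transition::Next(FrameState::Body(ReadBuffer::with_target(len)))
            } else {
                Transition::Same // still waiting for the rest of the prefix
            }
        }
        FrameState::Body(buf) => {
            buf.extend(input);
            if buf.is_complete() {
                // Hand the complete frame to the caller and finish.
                Transition::Yield(buf.take(), FrameState::Done)
            } else {
                Transition::Same
            }
        }
        FrameState::Done => Transition::Same,
    }
}
```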