ET + SM + Crypto Fixes
This commit is contained in:
@@ -46,6 +46,7 @@ base64 = "0.21"
|
|||||||
url = "2.5"
|
url = "2.5"
|
||||||
regex = "1.10"
|
regex = "1.10"
|
||||||
once_cell = "1.19"
|
once_cell = "1.19"
|
||||||
|
crossbeam-queue = "0.3"
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
||||||
|
|||||||
@@ -1,21 +1,24 @@
|
|||||||
//! AES
|
//! AES encryption implementations
|
||||||
|
//!
|
||||||
|
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
|
||||||
|
|
||||||
use aes::Aes256;
|
use aes::Aes256;
|
||||||
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
||||||
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
|
|
||||||
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
|
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
type Aes256Ctr = Ctr128BE<Aes256>;
|
type Aes256Ctr = Ctr128BE<Aes256>;
|
||||||
type Aes256CbcEnc = CbcEncryptor<Aes256>;
|
|
||||||
type Aes256CbcDec = CbcDecryptor<Aes256>;
|
// ============= AES-256-CTR =============
|
||||||
|
|
||||||
/// AES-256-CTR encryptor/decryptor
|
/// AES-256-CTR encryptor/decryptor
|
||||||
|
///
|
||||||
|
/// CTR mode is symmetric - encryption and decryption are the same operation.
|
||||||
pub struct AesCtr {
|
pub struct AesCtr {
|
||||||
cipher: Aes256Ctr,
|
cipher: Aes256Ctr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AesCtr {
|
impl AesCtr {
|
||||||
|
/// Create new AES-CTR cipher with key and IV
|
||||||
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
||||||
let iv_bytes = iv.to_be_bytes();
|
let iv_bytes = iv.to_be_bytes();
|
||||||
Self {
|
Self {
|
||||||
@@ -23,6 +26,7 @@ impl AesCtr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create from key and IV slices
|
||||||
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||||
if key.len() != 32 {
|
if key.len() != 32 {
|
||||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||||
@@ -54,17 +58,28 @@ impl AesCtr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// AES-256-CBC Ciphermagic
|
// ============= AES-256-CBC =============
|
||||||
|
|
||||||
|
/// AES-256-CBC cipher with proper chaining
|
||||||
|
///
|
||||||
|
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
|
||||||
|
/// are different operations. This implementation handles CBC chaining
|
||||||
|
/// correctly across multiple blocks.
|
||||||
pub struct AesCbc {
|
pub struct AesCbc {
|
||||||
key: [u8; 32],
|
key: [u8; 32],
|
||||||
iv: [u8; 16],
|
iv: [u8; 16],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AesCbc {
|
impl AesCbc {
|
||||||
|
/// AES block size
|
||||||
|
const BLOCK_SIZE: usize = 16;
|
||||||
|
|
||||||
|
/// Create new AES-CBC cipher with key and IV
|
||||||
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
||||||
Self { key, iv }
|
Self { key, iv }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create from slices
|
||||||
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||||
if key.len() != 32 {
|
if key.len() != 32 {
|
||||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||||
@@ -79,32 +94,36 @@ impl AesCbc {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt data using CBC mode
|
/// Encrypt a single block using raw AES (no chaining)
|
||||||
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||||
if data.len() % 16 != 0 {
|
use aes::cipher::BlockEncrypt;
|
||||||
return Err(ProxyError::Crypto(
|
let mut output = *block;
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
key_schedule.encrypt_block((&mut output).into());
|
||||||
));
|
output
|
||||||
}
|
|
||||||
|
|
||||||
if data.is_empty() {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut buffer = data.to_vec();
|
|
||||||
|
|
||||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
|
||||||
|
|
||||||
for chunk in buffer.chunks_mut(16) {
|
|
||||||
encryptor.encrypt_block_mut(chunk.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(buffer)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrypt data using CBC mode
|
/// Decrypt a single block using raw AES (no chaining)
|
||||||
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||||
if data.len() % 16 != 0 {
|
use aes::cipher::BlockDecrypt;
|
||||||
|
let mut output = *block;
|
||||||
|
key_schedule.decrypt_block((&mut output).into());
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// XOR two 16-byte blocks
|
||||||
|
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
|
||||||
|
let mut result = [0u8; 16];
|
||||||
|
for i in 0..16 {
|
||||||
|
result[i] = a[i] ^ b[i];
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypt data using CBC mode with proper chaining
|
||||||
|
///
|
||||||
|
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
|
||||||
|
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||||
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -114,20 +133,73 @@ impl AesCbc {
|
|||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buffer = data.to_vec();
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
let mut result = Vec::with_capacity(data.len());
|
||||||
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
for chunk in buffer.chunks_mut(16) {
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
decryptor.decrypt_block_mut(chunk.into());
|
let plaintext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
|
||||||
|
// XOR plaintext with previous ciphertext (or IV for first block)
|
||||||
|
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
|
||||||
|
|
||||||
|
// Encrypt the XORed block
|
||||||
|
let ciphertext = self.encrypt_block(&xored, &key_schedule);
|
||||||
|
|
||||||
|
// Save for next iteration
|
||||||
|
prev_ciphertext = ciphertext;
|
||||||
|
|
||||||
|
// Append to result
|
||||||
|
result.extend_from_slice(&ciphertext);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(buffer)
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt data using CBC mode with proper chaining
|
||||||
|
///
|
||||||
|
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
|
||||||
|
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||||
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
|
return Err(ProxyError::Crypto(
|
||||||
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
|
let mut result = Vec::with_capacity(data.len());
|
||||||
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
|
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
|
||||||
|
// Decrypt the block
|
||||||
|
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
|
||||||
|
|
||||||
|
// XOR with previous ciphertext (or IV for first block)
|
||||||
|
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
|
||||||
|
|
||||||
|
// Save current ciphertext for next iteration
|
||||||
|
prev_ciphertext = ciphertext;
|
||||||
|
|
||||||
|
// Append to result
|
||||||
|
result.extend_from_slice(&plaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt data in-place
|
/// Encrypt data in-place
|
||||||
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -137,10 +209,25 @@ impl AesCbc {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
for chunk in data.chunks_mut(16) {
|
let mut prev_ciphertext = self.iv;
|
||||||
encryptor.encrypt_block_mut(chunk.into());
|
|
||||||
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
|
// XOR with previous ciphertext
|
||||||
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
|
block[j] ^= prev_ciphertext[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt in-place
|
||||||
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
|
*block_array = self.encrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
|
// Save for next iteration
|
||||||
|
prev_ciphertext = *block_array;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -148,7 +235,7 @@ impl AesCbc {
|
|||||||
|
|
||||||
/// Decrypt data in-place
|
/// Decrypt data in-place
|
||||||
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -158,16 +245,38 @@ impl AesCbc {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
for chunk in data.chunks_mut(16) {
|
// For in-place decryption, we need to save ciphertext blocks
|
||||||
decryptor.decrypt_block_mut(chunk.into());
|
// before we overwrite them
|
||||||
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
|
// Save current ciphertext before modifying
|
||||||
|
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
|
||||||
|
|
||||||
|
// Decrypt in-place
|
||||||
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
|
*block_array = self.decrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
|
// XOR with previous ciphertext
|
||||||
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
|
block[j] ^= prev_ciphertext[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save for next iteration
|
||||||
|
prev_ciphertext = current_ciphertext;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= Encryption Traits =============
|
||||||
|
|
||||||
/// Trait for unified encryption interface
|
/// Trait for unified encryption interface
|
||||||
pub trait Encryptor: Send + Sync {
|
pub trait Encryptor: Send + Sync {
|
||||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
||||||
@@ -209,6 +318,8 @@ impl Decryptor for PassthroughEncryptor {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
// ============= AES-CTR Tests =============
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_ctr_roundtrip() {
|
fn test_aes_ctr_roundtrip() {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
@@ -225,13 +336,35 @@ mod tests {
|
|||||||
assert_eq!(original.as_slice(), decrypted.as_slice());
|
assert_eq!(original.as_slice(), decrypted.as_slice());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_ctr_in_place() {
|
||||||
|
let key = [0x42u8; 32];
|
||||||
|
let iv = 999u128;
|
||||||
|
|
||||||
|
let original = b"Test data for in-place encryption";
|
||||||
|
let mut data = original.to_vec();
|
||||||
|
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
|
// Encrypted should be different
|
||||||
|
assert_ne!(&data[..], original);
|
||||||
|
|
||||||
|
// Decrypt with fresh cipher
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
|
assert_eq!(&data[..], original);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= AES-CBC Tests =============
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_roundtrip() {
|
fn test_aes_cbc_roundtrip() {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
|
|
||||||
// Must be aligned to 16 bytes
|
let original = [0u8; 32]; // 2 blocks
|
||||||
let original = [0u8; 32];
|
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let encrypted = cipher.encrypt(&original).unwrap();
|
let encrypted = cipher.encrypt(&original).unwrap();
|
||||||
@@ -242,47 +375,59 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_chaining_works() {
|
fn test_aes_cbc_chaining_works() {
|
||||||
|
// This is the key test - verify CBC chaining is correct
|
||||||
let key = [0x42u8; 32];
|
let key = [0x42u8; 32];
|
||||||
let iv = [0x00u8; 16];
|
let iv = [0x00u8; 16];
|
||||||
|
|
||||||
let plaintext = [0xAA_u8; 32];
|
// Two IDENTICAL plaintext blocks
|
||||||
|
let plaintext = [0xAAu8; 32];
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// CBC Corrections
|
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
|
||||||
let block1 = &ciphertext[0..16];
|
let block1 = &ciphertext[0..16];
|
||||||
let block2 = &ciphertext[16..32];
|
let block2 = &ciphertext[16..32];
|
||||||
|
|
||||||
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
|
assert_ne!(
|
||||||
|
block1, block2,
|
||||||
|
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \
|
||||||
|
This indicates ECB mode, not CBC!"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_known_vector() {
|
fn test_aes_cbc_known_vector() {
|
||||||
|
// Test with known NIST test vector
|
||||||
|
// AES-256-CBC with zero key and zero IV
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
|
let plaintext = [0u8; 16];
|
||||||
// 3 Datablocks
|
|
||||||
let plaintext = [
|
|
||||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
|
||||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
|
||||||
// Block 2
|
|
||||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
|
||||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
|
||||||
// Block 3 - different
|
|
||||||
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
|
|
||||||
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
|
|
||||||
];
|
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// Decrypt + Verify
|
// Decrypt and verify roundtrip
|
||||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||||
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
||||||
|
|
||||||
// Verify Ciphertexts Block 1 != Block 2
|
// Ciphertext should not be all zeros
|
||||||
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
|
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_multi_block() {
|
||||||
|
let key = [0x12u8; 32];
|
||||||
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
|
// 5 blocks = 80 bytes
|
||||||
|
let plaintext: Vec<u8> = (0..80).collect();
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(plaintext, decrypted);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -291,7 +436,7 @@ mod tests {
|
|||||||
let iv = [0x34u8; 16];
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
let original = [0x56u8; 48]; // 3 blocks
|
let original = [0x56u8; 48]; // 3 blocks
|
||||||
let mut buffer = original.clone();
|
let mut buffer = original;
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
@@ -317,35 +462,85 @@ mod tests {
|
|||||||
fn test_aes_cbc_unaligned_error() {
|
fn test_aes_cbc_unaligned_error() {
|
||||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||||
|
|
||||||
// 15 bytes
|
// 15 bytes - not aligned to block size
|
||||||
let result = cipher.encrypt(&[0u8; 15]);
|
let result = cipher.encrypt(&[0u8; 15]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
// 17 bytes
|
// 17 bytes - not aligned
|
||||||
let result = cipher.encrypt(&[0u8; 17]);
|
let result = cipher.encrypt(&[0u8; 17]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_avalanche_effect() {
|
fn test_aes_cbc_avalanche_effect() {
|
||||||
// Cipherplane
|
// Changing one bit in plaintext should change entire ciphertext block
|
||||||
|
// and all subsequent blocks (due to chaining)
|
||||||
let key = [0xAB; 32];
|
let key = [0xAB; 32];
|
||||||
let iv = [0xCD; 16];
|
let iv = [0xCD; 16];
|
||||||
|
|
||||||
let mut plaintext1 = [0u8; 32];
|
let mut plaintext1 = [0u8; 32];
|
||||||
let mut plaintext2 = [0u8; 32];
|
let mut plaintext2 = [0u8; 32];
|
||||||
plaintext2[0] = 0x01; // Один бит отличается
|
plaintext2[0] = 0x01; // Single bit difference in first block
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
||||||
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
||||||
|
|
||||||
// First Blocks Diff
|
// First blocks should be different
|
||||||
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
||||||
|
|
||||||
// Second Blocks Diff
|
// Second blocks should ALSO be different (chaining effect)
|
||||||
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_iv_matters() {
|
||||||
|
// Same plaintext with different IVs should produce different ciphertext
|
||||||
|
let key = [0x55; 32];
|
||||||
|
let plaintext = [0x77u8; 16];
|
||||||
|
|
||||||
|
let cipher1 = AesCbc::new(key, [0u8; 16]);
|
||||||
|
let cipher2 = AesCbc::new(key, [1u8; 16]);
|
||||||
|
|
||||||
|
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
|
||||||
|
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
|
assert_ne!(ciphertext1, ciphertext2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_deterministic() {
|
||||||
|
// Same key, IV, plaintext should always produce same ciphertext
|
||||||
|
let key = [0x99; 32];
|
||||||
|
let iv = [0x88; 16];
|
||||||
|
let plaintext = [0x77u8; 32];
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
|
let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(ciphertext1, ciphertext2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Error Handling Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_key_length() {
|
||||||
|
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_iv_length() {
|
||||||
|
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
261
src/error.rs
261
src/error.rs
@@ -1,8 +1,177 @@
|
|||||||
//! Error Types
|
//! Error Types
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
// ============= Stream Errors =============
|
||||||
|
|
||||||
|
/// Errors specific to stream I/O operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StreamError {
|
||||||
|
/// Partial read: got fewer bytes than expected
|
||||||
|
PartialRead {
|
||||||
|
expected: usize,
|
||||||
|
got: usize,
|
||||||
|
},
|
||||||
|
/// Partial write: wrote fewer bytes than expected
|
||||||
|
PartialWrite {
|
||||||
|
expected: usize,
|
||||||
|
written: usize,
|
||||||
|
},
|
||||||
|
/// Stream is in poisoned state and cannot be used
|
||||||
|
Poisoned {
|
||||||
|
reason: String,
|
||||||
|
},
|
||||||
|
/// Buffer overflow: attempted to buffer more than allowed
|
||||||
|
BufferOverflow {
|
||||||
|
limit: usize,
|
||||||
|
attempted: usize,
|
||||||
|
},
|
||||||
|
/// Invalid frame format
|
||||||
|
InvalidFrame {
|
||||||
|
details: String,
|
||||||
|
},
|
||||||
|
/// Unexpected end of stream
|
||||||
|
UnexpectedEof,
|
||||||
|
/// Underlying I/O error
|
||||||
|
Io(std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for StreamError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::PartialRead { expected, got } => {
|
||||||
|
write!(f, "partial read: expected {} bytes, got {}", expected, got)
|
||||||
|
}
|
||||||
|
Self::PartialWrite { expected, written } => {
|
||||||
|
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
|
||||||
|
}
|
||||||
|
Self::Poisoned { reason } => {
|
||||||
|
write!(f, "stream poisoned: {}", reason)
|
||||||
|
}
|
||||||
|
Self::BufferOverflow { limit, attempted } => {
|
||||||
|
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
|
||||||
|
}
|
||||||
|
Self::InvalidFrame { details } => {
|
||||||
|
write!(f, "invalid frame: {}", details)
|
||||||
|
}
|
||||||
|
Self::UnexpectedEof => {
|
||||||
|
write!(f, "unexpected end of stream")
|
||||||
|
}
|
||||||
|
Self::Io(e) => {
|
||||||
|
write!(f, "I/O error: {}", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for StreamError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Io(e) => Some(e),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for StreamError {
|
||||||
|
fn from(err: std::io::Error) -> Self {
|
||||||
|
Self::Io(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StreamError> for std::io::Error {
|
||||||
|
fn from(err: StreamError) -> Self {
|
||||||
|
match err {
|
||||||
|
StreamError::Io(e) => e,
|
||||||
|
StreamError::UnexpectedEof => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
|
||||||
|
}
|
||||||
|
StreamError::Poisoned { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||||
|
}
|
||||||
|
StreamError::BufferOverflow { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
|
||||||
|
}
|
||||||
|
StreamError::InvalidFrame { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
|
||||||
|
}
|
||||||
|
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Recoverable Trait =============
|
||||||
|
|
||||||
|
/// Trait for errors that may be recoverable
|
||||||
|
pub trait Recoverable {
|
||||||
|
/// Check if error is recoverable (can retry operation)
|
||||||
|
fn is_recoverable(&self) -> bool;
|
||||||
|
|
||||||
|
/// Check if connection can continue after this error
|
||||||
|
fn can_continue(&self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Recoverable for StreamError {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
// Partial operations can be retried
|
||||||
|
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
||||||
|
// I/O errors depend on kind
|
||||||
|
Self::Io(e) => matches!(
|
||||||
|
e.kind(),
|
||||||
|
std::io::ErrorKind::WouldBlock
|
||||||
|
| std::io::ErrorKind::Interrupted
|
||||||
|
| std::io::ErrorKind::TimedOut
|
||||||
|
),
|
||||||
|
// These are not recoverable
|
||||||
|
Self::Poisoned { .. }
|
||||||
|
| Self::BufferOverflow { .. }
|
||||||
|
| Self::InvalidFrame { .. }
|
||||||
|
| Self::UnexpectedEof => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
// Poisoned stream cannot be used
|
||||||
|
Self::Poisoned { .. } => false,
|
||||||
|
// EOF means stream is done
|
||||||
|
Self::UnexpectedEof => false,
|
||||||
|
// Buffer overflow is fatal
|
||||||
|
Self::BufferOverflow { .. } => false,
|
||||||
|
// Others might allow continuation
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Recoverable for std::io::Error {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self.kind(),
|
||||||
|
std::io::ErrorKind::WouldBlock
|
||||||
|
| std::io::ErrorKind::Interrupted
|
||||||
|
| std::io::ErrorKind::TimedOut
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
!matches!(
|
||||||
|
self.kind(),
|
||||||
|
std::io::ErrorKind::BrokenPipe
|
||||||
|
| std::io::ErrorKind::ConnectionReset
|
||||||
|
| std::io::ErrorKind::ConnectionAborted
|
||||||
|
| std::io::ErrorKind::NotConnected
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Main Proxy Errors =============
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ProxyError {
|
pub enum ProxyError {
|
||||||
// ============= Crypto Errors =============
|
// ============= Crypto Errors =============
|
||||||
@@ -13,6 +182,11 @@ pub enum ProxyError {
|
|||||||
#[error("Invalid key length: expected {expected}, got {got}")]
|
#[error("Invalid key length: expected {expected}, got {got}")]
|
||||||
InvalidKeyLength { expected: usize, got: usize },
|
InvalidKeyLength { expected: usize, got: usize },
|
||||||
|
|
||||||
|
// ============= Stream Errors =============
|
||||||
|
|
||||||
|
#[error("Stream error: {0}")]
|
||||||
|
Stream(#[from] StreamError),
|
||||||
|
|
||||||
// ============= Protocol Errors =============
|
// ============= Protocol Errors =============
|
||||||
|
|
||||||
#[error("Invalid handshake: {0}")]
|
#[error("Invalid handshake: {0}")]
|
||||||
@@ -39,6 +213,12 @@ pub enum ProxyError {
|
|||||||
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
||||||
SeqNoMismatch { expected: i32, got: i32 },
|
SeqNoMismatch { expected: i32, got: i32 },
|
||||||
|
|
||||||
|
#[error("TLS handshake failed: {reason}")]
|
||||||
|
TlsHandshakeFailed { reason: String },
|
||||||
|
|
||||||
|
#[error("Telegram handshake timeout")]
|
||||||
|
TgHandshakeTimeout,
|
||||||
|
|
||||||
// ============= Network Errors =============
|
// ============= Network Errors =============
|
||||||
|
|
||||||
#[error("Connection timeout to {addr}")]
|
#[error("Connection timeout to {addr}")]
|
||||||
@@ -77,15 +257,41 @@ pub enum ProxyError {
|
|||||||
#[error("Unknown user")]
|
#[error("Unknown user")]
|
||||||
UnknownUser,
|
UnknownUser,
|
||||||
|
|
||||||
|
#[error("Rate limited")]
|
||||||
|
RateLimited,
|
||||||
|
|
||||||
// ============= General Errors =============
|
// ============= General Errors =============
|
||||||
|
|
||||||
#[error("Internal error: {0}")]
|
#[error("Internal error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Recoverable for ProxyError {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Stream(e) => e.is_recoverable(),
|
||||||
|
Self::Io(e) => e.is_recoverable(),
|
||||||
|
Self::ConnectionTimeout { .. } => true,
|
||||||
|
Self::RateLimited => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Stream(e) => e.can_continue(),
|
||||||
|
Self::Io(e) => e.can_continue(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convenient Result type alias
|
/// Convenient Result type alias
|
||||||
pub type Result<T> = std::result::Result<T, ProxyError>;
|
pub type Result<T> = std::result::Result<T, ProxyError>;
|
||||||
|
|
||||||
|
/// Result type for stream operations
|
||||||
|
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
||||||
|
|
||||||
/// Result with optional bad client handling
|
/// Result with optional bad client handling
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum HandshakeResult<T> {
|
pub enum HandshakeResult<T> {
|
||||||
@@ -125,6 +331,14 @@ impl<T> HandshakeResult<T> {
|
|||||||
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert success to Option
|
||||||
|
pub fn ok(self) -> Option<T> {
|
||||||
|
match self {
|
||||||
|
HandshakeResult::Success(v) => Some(v),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<ProxyError> for HandshakeResult<T> {
|
impl<T> From<ProxyError> for HandshakeResult<T> {
|
||||||
@@ -139,10 +353,48 @@ impl<T> From<std::io::Error> for HandshakeResult<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> From<StreamError> for HandshakeResult<T> {
|
||||||
|
fn from(err: StreamError) -> Self {
|
||||||
|
HandshakeResult::Error(ProxyError::Stream(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_display() {
|
||||||
|
let err = StreamError::PartialRead { expected: 100, got: 50 };
|
||||||
|
assert!(err.to_string().contains("100"));
|
||||||
|
assert!(err.to_string().contains("50"));
|
||||||
|
|
||||||
|
let err = StreamError::Poisoned { reason: "test".into() };
|
||||||
|
assert!(err.to_string().contains("test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_recoverable() {
|
||||||
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
|
||||||
|
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
|
||||||
|
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
|
||||||
|
assert!(!StreamError::UnexpectedEof.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_can_continue() {
|
||||||
|
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
|
||||||
|
assert!(!StreamError::UnexpectedEof.can_continue());
|
||||||
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_to_io_error() {
|
||||||
|
let stream_err = StreamError::UnexpectedEof;
|
||||||
|
let io_err: std::io::Error = stream_err.into();
|
||||||
|
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_handshake_result() {
|
fn test_handshake_result() {
|
||||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
||||||
@@ -165,6 +417,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_proxy_error_recoverable() {
|
||||||
|
let err = ProxyError::RateLimited;
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
|
||||||
|
let err = ProxyError::InvalidHandshake("bad".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_error_display() {
|
fn test_error_display() {
|
||||||
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
||||||
|
|||||||
450
src/stream/buffer_pool.rs
Normal file
450
src/stream/buffer_pool.rs
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
//! Reusable buffer pool to avoid allocations in hot paths
|
||||||
|
//!
|
||||||
|
//! This module provides a thread-safe pool of BytesMut buffers
|
||||||
|
//! that can be reused across connections to reduce allocation pressure.
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use crossbeam_queue::ArrayQueue;
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
// ============= Configuration =============
|
||||||
|
|
||||||
|
/// Default buffer size (64KB - good for MTProto)
|
||||||
|
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
|
||||||
|
|
||||||
|
/// Default maximum number of pooled buffers
|
||||||
|
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||||
|
|
||||||
|
// ============= Buffer Pool =============
|
||||||
|
|
||||||
|
/// Thread-safe pool of reusable buffers
|
||||||
|
pub struct BufferPool {
|
||||||
|
/// Queue of available buffers
|
||||||
|
buffers: ArrayQueue<BytesMut>,
|
||||||
|
/// Size of each buffer
|
||||||
|
buffer_size: usize,
|
||||||
|
/// Maximum number of buffers to pool
|
||||||
|
max_buffers: usize,
|
||||||
|
/// Total allocated buffers (including in-use)
|
||||||
|
allocated: AtomicUsize,
|
||||||
|
/// Number of times we had to create a new buffer
|
||||||
|
misses: AtomicUsize,
|
||||||
|
/// Number of successful reuses
|
||||||
|
hits: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BufferPool {
|
||||||
|
/// Create a new buffer pool with default settings
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a buffer pool with custom configuration
|
||||||
|
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffers: ArrayQueue::new(max_buffers),
|
||||||
|
buffer_size,
|
||||||
|
max_buffers,
|
||||||
|
allocated: AtomicUsize::new(0),
|
||||||
|
misses: AtomicUsize::new(0),
|
||||||
|
hits: AtomicUsize::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a buffer from the pool, or create a new one if empty
|
||||||
|
pub fn get(self: &Arc<Self>) -> PooledBuffer {
|
||||||
|
match self.buffers.pop() {
|
||||||
|
Some(mut buffer) => {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
buffer.clear();
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(buffer),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(BytesMut::with_capacity(self.buffer_size)),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to get a buffer, returns None if pool is empty
|
||||||
|
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
|
||||||
|
self.buffers.pop().map(|mut buffer| {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
buffer.clear();
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(buffer),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a buffer to the pool
|
||||||
|
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||||
|
// Clear the buffer but keep capacity
|
||||||
|
buffer.clear();
|
||||||
|
|
||||||
|
// Only return if we haven't exceeded max and buffer is right size
|
||||||
|
if buffer.capacity() >= self.buffer_size {
|
||||||
|
// Try to push to pool, if full just drop
|
||||||
|
let _ = self.buffers.push(buffer);
|
||||||
|
}
|
||||||
|
// If buffer was dropped (pool full), decrement allocated
|
||||||
|
// Actually we don't decrement here because the buffer might have been
|
||||||
|
// grown beyond our size - we just let it go
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get pool statistics
|
||||||
|
pub fn stats(&self) -> PoolStats {
|
||||||
|
PoolStats {
|
||||||
|
pooled: self.buffers.len(),
|
||||||
|
allocated: self.allocated.load(Ordering::Relaxed),
|
||||||
|
max_buffers: self.max_buffers,
|
||||||
|
buffer_size: self.buffer_size,
|
||||||
|
hits: self.hits.load(Ordering::Relaxed),
|
||||||
|
misses: self.misses.load(Ordering::Relaxed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get buffer size
|
||||||
|
pub fn buffer_size(&self) -> usize {
|
||||||
|
self.buffer_size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Preallocate buffers to fill the pool
|
||||||
|
pub fn preallocate(&self, count: usize) {
|
||||||
|
let to_alloc = count.min(self.max_buffers);
|
||||||
|
for _ in 0..to_alloc {
|
||||||
|
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BufferPool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Pool Statistics =============
|
||||||
|
|
||||||
|
/// Statistics about buffer pool usage
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PoolStats {
|
||||||
|
/// Current number of buffers in pool
|
||||||
|
pub pooled: usize,
|
||||||
|
/// Total buffers allocated (in-use + pooled)
|
||||||
|
pub allocated: usize,
|
||||||
|
/// Maximum buffers allowed
|
||||||
|
pub max_buffers: usize,
|
||||||
|
/// Size of each buffer
|
||||||
|
pub buffer_size: usize,
|
||||||
|
/// Number of cache hits (reused buffer)
|
||||||
|
pub hits: usize,
|
||||||
|
/// Number of cache misses (new allocation)
|
||||||
|
pub misses: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PoolStats {
|
||||||
|
/// Get hit rate as percentage
|
||||||
|
pub fn hit_rate(&self) -> f64 {
|
||||||
|
let total = self.hits + self.misses;
|
||||||
|
if total == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
(self.hits as f64 / total as f64) * 100.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Pooled Buffer =============
|
||||||
|
|
||||||
|
/// A buffer that automatically returns to the pool when dropped
|
||||||
|
pub struct PooledBuffer {
|
||||||
|
buffer: Option<BytesMut>,
|
||||||
|
pool: Arc<BufferPool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PooledBuffer {
|
||||||
|
/// Take the inner buffer, preventing return to pool
|
||||||
|
pub fn take(mut self) -> BytesMut {
|
||||||
|
self.buffer.take().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the capacity of the buffer
|
||||||
|
pub fn capacity(&self) -> usize {
|
||||||
|
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the length of data in buffer
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
if let Some(ref mut b) = self.buffer {
|
||||||
|
b.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for PooledBuffer {
|
||||||
|
type Target = BytesMut;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.buffer.as_ref().expect("buffer taken")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DerefMut for PooledBuffer {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
self.buffer.as_mut().expect("buffer taken")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for PooledBuffer {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(buffer) = self.buffer.take() {
|
||||||
|
self.pool.return_buffer(buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRef<[u8]> for PooledBuffer {
|
||||||
|
fn as_ref(&self) -> &[u8] {
|
||||||
|
self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsMut<[u8]> for PooledBuffer {
|
||||||
|
fn as_mut(&mut self) -> &mut [u8] {
|
||||||
|
self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut [])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Scoped Buffer =============
|
||||||
|
|
||||||
|
/// A buffer that can be used for a scoped operation
|
||||||
|
/// Useful for ensuring buffer is returned even on early return
|
||||||
|
pub struct ScopedBuffer<'a> {
|
||||||
|
buffer: &'a mut PooledBuffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ScopedBuffer<'a> {
|
||||||
|
/// Create a new scoped buffer
|
||||||
|
pub fn new(buffer: &'a mut PooledBuffer) -> Self {
|
||||||
|
buffer.clear();
|
||||||
|
Self { buffer }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Deref for ScopedBuffer<'a> {
|
||||||
|
type Target = BytesMut;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.buffer.deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DerefMut for ScopedBuffer<'a> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
self.buffer.deref_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for ScopedBuffer<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.buffer.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_basic() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
|
||||||
|
// Get a buffer
|
||||||
|
let mut buf1 = pool.get();
|
||||||
|
buf1.extend_from_slice(b"hello");
|
||||||
|
assert_eq!(&buf1[..], b"hello");
|
||||||
|
|
||||||
|
// Drop returns to pool
|
||||||
|
drop(buf1);
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 1);
|
||||||
|
assert_eq!(stats.hits, 0);
|
||||||
|
assert_eq!(stats.misses, 1);
|
||||||
|
|
||||||
|
// Get again - should reuse
|
||||||
|
let buf2 = pool.get();
|
||||||
|
assert!(buf2.is_empty()); // Buffer was cleared
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 0);
|
||||||
|
assert_eq!(stats.hits, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_multiple_buffers() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
|
||||||
|
// Get multiple buffers
|
||||||
|
let buf1 = pool.get();
|
||||||
|
let buf2 = pool.get();
|
||||||
|
let buf3 = pool.get();
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.allocated, 3);
|
||||||
|
assert_eq!(stats.pooled, 0);
|
||||||
|
|
||||||
|
// Return all
|
||||||
|
drop(buf1);
|
||||||
|
drop(buf2);
|
||||||
|
drop(buf3);
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_overflow() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 2));
|
||||||
|
|
||||||
|
// Get 3 buffers (more than max)
|
||||||
|
let buf1 = pool.get();
|
||||||
|
let buf2 = pool.get();
|
||||||
|
let buf3 = pool.get();
|
||||||
|
|
||||||
|
// Return all - only 2 should be pooled
|
||||||
|
drop(buf1);
|
||||||
|
drop(buf2);
|
||||||
|
drop(buf3);
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_take() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
|
||||||
|
let mut buf = pool.get();
|
||||||
|
buf.extend_from_slice(b"data");
|
||||||
|
|
||||||
|
// Take ownership, buffer should not return to pool
|
||||||
|
let taken = buf.take();
|
||||||
|
assert_eq!(&taken[..], b"data");
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_preallocate() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
pool.preallocate(5);
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.pooled, 5);
|
||||||
|
assert_eq!(stats.allocated, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pool_try_get() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
|
||||||
|
// Pool is empty, try_get returns None
|
||||||
|
assert!(pool.try_get().is_none());
|
||||||
|
|
||||||
|
// Add a buffer to pool
|
||||||
|
pool.preallocate(1);
|
||||||
|
|
||||||
|
// Now try_get should succeed
|
||||||
|
assert!(pool.try_get().is_some());
|
||||||
|
assert!(pool.try_get().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_hit_rate() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
|
||||||
|
// First get is a miss
|
||||||
|
let buf1 = pool.get();
|
||||||
|
drop(buf1);
|
||||||
|
|
||||||
|
// Second get is a hit
|
||||||
|
let buf2 = pool.get();
|
||||||
|
drop(buf2);
|
||||||
|
|
||||||
|
// Third get is a hit
|
||||||
|
let _buf3 = pool.get();
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
assert_eq!(stats.hits, 2);
|
||||||
|
assert_eq!(stats.misses, 1);
|
||||||
|
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_scoped_buffer() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||||
|
let mut buf = pool.get();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut scoped = ScopedBuffer::new(&mut buf);
|
||||||
|
scoped.extend_from_slice(b"scoped data");
|
||||||
|
assert_eq!(&scoped[..], b"scoped data");
|
||||||
|
}
|
||||||
|
|
||||||
|
// After scoped is dropped, buffer is cleared
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_concurrent_access() {
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(1024, 100));
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
for _ in 0..10 {
|
||||||
|
let pool_clone = Arc::clone(&pool);
|
||||||
|
handles.push(thread::spawn(move || {
|
||||||
|
for _ in 0..100 {
|
||||||
|
let mut buf = pool_clone.get();
|
||||||
|
buf.extend_from_slice(b"test");
|
||||||
|
// buf auto-returned on drop
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle.join().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
// All buffers should be returned
|
||||||
|
assert!(stats.pooled > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,22 @@
|
|||||||
//! Stream wrappers for MTProto protocol layers
|
//! Stream wrappers for MTProto protocol layers
|
||||||
|
|
||||||
|
pub mod state;
|
||||||
|
pub mod buffer_pool;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
pub mod crypto_stream;
|
pub mod crypto_stream;
|
||||||
pub mod tls_stream;
|
pub mod tls_stream;
|
||||||
pub mod frame_stream;
|
pub mod frame_stream;
|
||||||
|
|
||||||
|
// Re-export state machine types
|
||||||
|
pub use state::{
|
||||||
|
StreamState, Transition, PollResult,
|
||||||
|
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Re-export buffer pool
|
||||||
|
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
|
||||||
|
|
||||||
|
// Re-export stream implementations
|
||||||
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
||||||
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||||
pub use frame_stream::*;
|
pub use frame_stream::*;
|
||||||
571
src/stream/state.rs
Normal file
571
src/stream/state.rs
Normal file
@@ -0,0 +1,571 @@
|
|||||||
|
//! State machine foundation types for async streams
|
||||||
|
//!
|
||||||
|
//! This module provides core types and traits for implementing
|
||||||
|
//! stateful async streams with proper partial read/write handling.
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
// ============= Core Traits =============
|
||||||
|
|
||||||
|
/// Trait for stream states
|
||||||
|
pub trait StreamState: Sized {
|
||||||
|
/// Check if this is a terminal state (no more transitions possible)
|
||||||
|
fn is_terminal(&self) -> bool;
|
||||||
|
|
||||||
|
/// Check if stream is in poisoned/error state
|
||||||
|
fn is_poisoned(&self) -> bool;
|
||||||
|
|
||||||
|
/// Get human-readable state name for debugging
|
||||||
|
fn state_name(&self) -> &'static str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Transition Types =============
|
||||||
|
|
||||||
|
/// Result of a state transition
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Transition<S, O> {
|
||||||
|
/// Stay in the same state, no output
|
||||||
|
Same,
|
||||||
|
/// Transition to a new state, no output
|
||||||
|
Next(S),
|
||||||
|
/// Complete with output, typically transitions to Idle
|
||||||
|
Complete(O),
|
||||||
|
/// Yield output and transition to new state
|
||||||
|
Yield(O, S),
|
||||||
|
/// Error occurred, transition to error state
|
||||||
|
Error(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, O> Transition<S, O> {
|
||||||
|
/// Check if transition produces output
|
||||||
|
pub fn has_output(&self) -> bool {
|
||||||
|
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the output value
|
||||||
|
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
|
||||||
|
match self {
|
||||||
|
Transition::Same => Transition::Same,
|
||||||
|
Transition::Next(s) => Transition::Next(s),
|
||||||
|
Transition::Complete(o) => Transition::Complete(f(o)),
|
||||||
|
Transition::Yield(o, s) => Transition::Yield(f(o), s),
|
||||||
|
Transition::Error(e) => Transition::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the state value
|
||||||
|
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
|
||||||
|
match self {
|
||||||
|
Transition::Same => Transition::Same,
|
||||||
|
Transition::Next(s) => Transition::Next(f(s)),
|
||||||
|
Transition::Complete(o) => Transition::Complete(o),
|
||||||
|
Transition::Yield(o, s) => Transition::Yield(o, f(s)),
|
||||||
|
Transition::Error(e) => Transition::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Poll Result Types =============
|
||||||
|
|
||||||
|
/// Result of polling for more data
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum PollResult<T> {
|
||||||
|
/// Data is ready
|
||||||
|
Ready(T),
|
||||||
|
/// Operation would block, need to poll again
|
||||||
|
Pending,
|
||||||
|
/// Need more input data (minimum bytes required)
|
||||||
|
NeedInput(usize),
|
||||||
|
/// End of stream reached
|
||||||
|
Eof,
|
||||||
|
/// Error occurred
|
||||||
|
Error(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> PollResult<T> {
|
||||||
|
/// Check if result is ready
|
||||||
|
pub fn is_ready(&self) -> bool {
|
||||||
|
matches!(self, PollResult::Ready(_))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if result indicates EOF
|
||||||
|
pub fn is_eof(&self) -> bool {
|
||||||
|
matches!(self, PollResult::Eof)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert to Option, discarding non-ready states
|
||||||
|
pub fn ok(self) -> Option<T> {
|
||||||
|
match self {
|
||||||
|
PollResult::Ready(t) => Some(t),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the value
|
||||||
|
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
|
||||||
|
match self {
|
||||||
|
PollResult::Ready(t) => PollResult::Ready(f(t)),
|
||||||
|
PollResult::Pending => PollResult::Pending,
|
||||||
|
PollResult::NeedInput(n) => PollResult::NeedInput(n),
|
||||||
|
PollResult::Eof => PollResult::Eof,
|
||||||
|
PollResult::Error(e) => PollResult::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<io::Result<T>> for PollResult<T> {
|
||||||
|
fn from(result: io::Result<T>) -> Self {
|
||||||
|
match result {
|
||||||
|
Ok(t) => PollResult::Ready(t),
|
||||||
|
Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending,
|
||||||
|
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof,
|
||||||
|
Err(e) => PollResult::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Buffer State =============
|
||||||
|
|
||||||
|
/// State for buffered reading operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ReadBuffer {
|
||||||
|
/// The buffer holding data
|
||||||
|
buffer: BytesMut,
|
||||||
|
/// Target number of bytes to read (if known)
|
||||||
|
target: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReadBuffer {
|
||||||
|
/// Create new empty read buffer
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(8192),
|
||||||
|
target: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with specific capacity
|
||||||
|
pub fn with_capacity(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(capacity),
|
||||||
|
target: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with target size
|
||||||
|
pub fn with_target(target: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(target),
|
||||||
|
target: Some(target),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current buffer length
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.buffer.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if target is reached
|
||||||
|
pub fn is_complete(&self) -> bool {
|
||||||
|
match self.target {
|
||||||
|
Some(t) => self.buffer.len() >= t,
|
||||||
|
None => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes needed
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
match self.target {
|
||||||
|
Some(t) => t.saturating_sub(self.buffer.len()),
|
||||||
|
None => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append data to buffer
|
||||||
|
pub fn extend(&mut self, data: &[u8]) {
|
||||||
|
self.buffer.extend_from_slice(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take all data from buffer
|
||||||
|
pub fn take(&mut self) -> Bytes {
|
||||||
|
self.target = None;
|
||||||
|
self.buffer.split().freeze()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take exactly n bytes
|
||||||
|
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
|
||||||
|
if self.buffer.len() >= n {
|
||||||
|
Some(self.buffer.split_to(n).freeze())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a slice of the buffer
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get mutable access to underlying BytesMut
|
||||||
|
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
|
||||||
|
&mut self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.target = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set new target
|
||||||
|
pub fn set_target(&mut self, target: usize) {
|
||||||
|
self.target = Some(target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ReadBuffer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State for buffered writing operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct WriteBuffer {
|
||||||
|
/// The buffer holding data to write
|
||||||
|
buffer: BytesMut,
|
||||||
|
/// Position of next byte to write
|
||||||
|
position: usize,
|
||||||
|
/// Maximum buffer size
|
||||||
|
max_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WriteBuffer {
|
||||||
|
/// Create new write buffer with default max size (256KB)
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_max_size(256 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with specific max size
|
||||||
|
pub fn with_max_size(max_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(8192),
|
||||||
|
position: 0,
|
||||||
|
max_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get pending bytes count
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len() - self.position
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty (all written)
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.position >= self.buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is full
|
||||||
|
pub fn is_full(&self) -> bool {
|
||||||
|
self.buffer.len() >= self.max_size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining capacity
|
||||||
|
pub fn remaining_capacity(&self) -> usize {
|
||||||
|
self.max_size.saturating_sub(self.buffer.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append data to buffer
|
||||||
|
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
|
||||||
|
if self.buffer.len() + data.len() > self.max_size {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
self.buffer.extend_from_slice(data);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get slice of data to write
|
||||||
|
pub fn pending(&self) -> &[u8] {
|
||||||
|
&self.buffer[self.position..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advance position by n bytes (after successful write)
|
||||||
|
pub fn advance(&mut self, n: usize) {
|
||||||
|
self.position += n;
|
||||||
|
|
||||||
|
// If all data written, reset buffer
|
||||||
|
if self.position >= self.buffer.len() {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.position = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.position = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WriteBuffer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Fixed-Size Buffer States =============
|
||||||
|
|
||||||
|
/// State for reading a fixed-size header
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HeaderBuffer<const N: usize> {
|
||||||
|
/// The buffer
|
||||||
|
data: [u8; N],
|
||||||
|
/// Bytes filled so far
|
||||||
|
filled: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> HeaderBuffer<N> {
|
||||||
|
/// Create new empty header buffer
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
data: [0u8; N],
|
||||||
|
filled: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get slice for reading into
|
||||||
|
pub fn unfilled_mut(&mut self) -> &mut [u8] {
|
||||||
|
&mut self.data[self.filled..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advance filled count
|
||||||
|
pub fn advance(&mut self, n: usize) {
|
||||||
|
self.filled = (self.filled + n).min(N);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if completely filled
|
||||||
|
pub fn is_complete(&self) -> bool {
|
||||||
|
self.filled >= N
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes needed
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
N - self.filled
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get filled bytes as slice
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.data[..self.filled]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get complete buffer (panics if not complete)
|
||||||
|
pub fn as_array(&self) -> &[u8; N] {
|
||||||
|
assert!(self.is_complete());
|
||||||
|
&self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take the buffer, resetting state
|
||||||
|
pub fn take(&mut self) -> [u8; N] {
|
||||||
|
let data = self.data;
|
||||||
|
self.data = [0u8; N];
|
||||||
|
self.filled = 0;
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset to empty state
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.filled = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> Default for HeaderBuffer<N> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Yield Buffer =============
|
||||||
|
|
||||||
|
/// Buffer for yielding data to caller in chunks
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct YieldBuffer {
|
||||||
|
data: Bytes,
|
||||||
|
position: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YieldBuffer {
|
||||||
|
/// Create new yield buffer
|
||||||
|
pub fn new(data: Bytes) -> Self {
|
||||||
|
Self { data, position: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if all data has been yielded
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.position >= self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
self.data.len() - self.position
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy data to output slice, return bytes copied
|
||||||
|
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
|
||||||
|
let available = &self.data[self.position..];
|
||||||
|
let to_copy = available.len().min(dst.len());
|
||||||
|
dst[..to_copy].copy_from_slice(&available[..to_copy]);
|
||||||
|
self.position += to_copy;
|
||||||
|
to_copy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining data as slice
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.data[self.position..]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Macros =============
|
||||||
|
|
||||||
|
/// Macro to simplify state transitions in poll methods
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! transition {
|
||||||
|
(same) => {
|
||||||
|
$crate::stream::state::Transition::Same
|
||||||
|
};
|
||||||
|
(next $state:expr) => {
|
||||||
|
$crate::stream::state::Transition::Next($state)
|
||||||
|
};
|
||||||
|
(complete $output:expr) => {
|
||||||
|
$crate::stream::state::Transition::Complete($output)
|
||||||
|
};
|
||||||
|
(yield $output:expr, $state:expr) => {
|
||||||
|
$crate::stream::state::Transition::Yield($output, $state)
|
||||||
|
};
|
||||||
|
(error $err:expr) => {
|
||||||
|
$crate::stream::state::Transition::Error($err)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Macro to match poll ready or return pending
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! ready_or_pending {
|
||||||
|
($poll:expr) => {
|
||||||
|
match $poll {
|
||||||
|
std::task::Poll::Ready(t) => t,
|
||||||
|
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_read_buffer_basic() {
|
||||||
|
let mut buf = ReadBuffer::with_target(10);
|
||||||
|
assert_eq!(buf.remaining(), 10);
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
|
||||||
|
buf.extend(b"hello");
|
||||||
|
assert_eq!(buf.len(), 5);
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
|
||||||
|
buf.extend(b"world");
|
||||||
|
assert_eq!(buf.len(), 10);
|
||||||
|
assert!(buf.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_read_buffer_take() {
|
||||||
|
let mut buf = ReadBuffer::new();
|
||||||
|
buf.extend(b"test data");
|
||||||
|
|
||||||
|
let data = buf.take();
|
||||||
|
assert_eq!(&data[..], b"test data");
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_buffer_basic() {
|
||||||
|
let mut buf = WriteBuffer::with_max_size(100);
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
|
||||||
|
buf.extend(b"hello").unwrap();
|
||||||
|
assert_eq!(buf.len(), 5);
|
||||||
|
assert!(!buf.is_empty());
|
||||||
|
|
||||||
|
buf.advance(3);
|
||||||
|
assert_eq!(buf.len(), 2);
|
||||||
|
assert_eq!(buf.pending(), b"lo");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_buffer_overflow() {
|
||||||
|
let mut buf = WriteBuffer::with_max_size(10);
|
||||||
|
assert!(buf.extend(b"short").is_ok());
|
||||||
|
assert!(buf.extend(b"toolong").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_header_buffer() {
|
||||||
|
let mut buf = HeaderBuffer::<5>::new();
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
|
||||||
|
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
|
||||||
|
buf.advance(3);
|
||||||
|
assert_eq!(buf.remaining(), 2);
|
||||||
|
|
||||||
|
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
|
||||||
|
buf.advance(2);
|
||||||
|
assert!(buf.is_complete());
|
||||||
|
assert_eq!(buf.as_array(), b"hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_yield_buffer() {
|
||||||
|
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
|
||||||
|
|
||||||
|
let mut dst = [0u8; 5];
|
||||||
|
assert_eq!(buf.copy_to(&mut dst), 5);
|
||||||
|
assert_eq!(&dst, b"hello");
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 6);
|
||||||
|
|
||||||
|
let mut dst = [0u8; 10];
|
||||||
|
assert_eq!(buf.copy_to(&mut dst), 6);
|
||||||
|
assert_eq!(&dst[..6], b" world");
|
||||||
|
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transition_map() {
|
||||||
|
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
|
||||||
|
let t = t.map_output(|s| s.len());
|
||||||
|
|
||||||
|
match t {
|
||||||
|
Transition::Complete(5) => {}
|
||||||
|
_ => panic!("Expected Complete(5)"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_poll_result() {
|
||||||
|
let r: PollResult<i32> = PollResult::Ready(42);
|
||||||
|
assert!(r.is_ready());
|
||||||
|
assert_eq!(r.ok(), Some(42));
|
||||||
|
|
||||||
|
let r: PollResult<i32> = PollResult::Eof;
|
||||||
|
assert!(r.is_eof());
|
||||||
|
assert_eq!(r.ok(), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user