27 Commits

Author SHA1 Message Date
Alexey
5778be4f6e Update README.md 2026-01-02 19:10:12 +03:00
Alexey
f443d3dfc7 Update README.md 2026-01-02 16:54:35 +03:00
Alexey
450cf180ad Update README.md 2026-01-02 16:33:42 +03:00
Alexey
84fa7face0 Update README.md 2026-01-02 16:33:07 +03:00
Alexey
f8a2ea1972 Update README.md 2026-01-02 16:31:55 +03:00
Alexey
96d0a6bdfa Update README.md 2026-01-02 16:31:29 +03:00
Alexey
eeee55e8ea Update README.md 2026-01-02 16:21:52 +03:00
Alexey
7be179b3c0 Added accurate MTProto Frame Types + Tokio Async Intergr 2026-01-02 01:37:02 +03:00
Alexey
b2e034f8f1 Deleting of inconsiderately added rs 2026-01-02 01:20:26 +03:00
Alexey
ffe5a6cfb7 Fake TLS Fixes for Async IO
added more comments and schemas
2026-01-02 01:17:56 +03:00
Alexey
0e096ca8fb TLS Stream Tuning 2026-01-01 23:48:52 +03:00
Alexey
50658525cf Merge branch 'main' of https://github.com/telemt/telemt 2026-01-01 23:34:13 +03:00
Alexey
4fd5ff4e83 ET + SM + Crypto Fixes 2026-01-01 23:34:04 +03:00
Alexey
df4f312fec Update rust.yml 2025-12-31 06:04:56 +03:00
Alexey
7d9a8b99b4 Update rust.yml 2025-12-31 06:01:59 +03:00
Alexey
06f34e55cd Update rust.yml 2025-12-31 05:59:20 +03:00
Alexey
153cb7f3a3 Create rust.yml 2025-12-31 05:54:45 +03:00
Alexey
7f8904a989 Update README.md 2025-12-31 05:48:17 +03:00
Alexey
0ee71a59a0 Update README.md 2025-12-31 05:44:48 +03:00
Alexey
45c7347e22 Update README.md 2025-12-31 05:29:09 +03:00
Alexey
3805237d74 Update README.md 2025-12-31 05:28:32 +03:00
Alexey
5b281bf7fd Create telemt.service
based Systemd service
2025-12-31 05:10:18 +03:00
Alexey
d64cccd52c Update README.md 2025-12-31 04:45:28 +03:00
Alexey
016fdada68 Update README.md 2025-12-31 04:39:49 +03:00
Alexey
2c2ceeaf54 Update README.md 2025-12-30 22:18:22 +03:00
Alexey
dd6badd786 Update README.md 2025-12-30 21:31:54 +03:00
Alexey
50e72368c8 Update README.md 2025-12-30 21:29:04 +03:00
14 changed files with 5081 additions and 424 deletions

41
.github/workflows/rust.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: Rust
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
env:
CARGO_TERM_COLOR: always
jobs:
build-and-test:
name: Build & Test
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
- name: Check for unused dependencies
run: cargo udeps || true

View File

@@ -46,6 +46,7 @@ base64 = "0.21"
url = "2.5"
regex = "1.10"
once_cell = "1.19"
crossbeam-queue = "0.3"
# HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }

160
README.md
View File

@@ -1,2 +1,158 @@
# telemt
MTProxy for Telegram on Rust + Tokio
# Telemt - MTProxy on Rust + Tokio
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
# GOTO
- [Features](#features)
- [Quick Start Guide](#quick-start-guide)
- [Build](#build)
- [How to use?](#how-to-use)
- [Systemd Method](#telemt-via-systemd)
- [FAQ](#faq)
- [Telegram Calls](#telegram-calls-via-mtproxy)
- [DPI](#how-does-dpi-see-mtproxy-tls)
- [Whitelist on Network Level](#whitelist-on-ip)
- [Why Rust?](#why-rust)
## Features
- Full support for all official MTProto proxy modes:
- Classic
- Secure - with `dd` prefix
- Fake TLS - with `ee` prefix + SNI fronting
- Replay attack protection
- Optional traffic masking: forward unrecognized connections to a real web server, e.g. GitHub 🤪
- Configurable keepalives + timeouts + IPv6 and "Fast Mode"
- Graceful shutdown on Ctrl+C
- Extensive logging via `trace` and `debug` with `RUST_LOG` method
## Quick Start Guide
### Build
```bash
# Cloning repo
git clone https://github.com/telemt/telemt
# Changing Directory to telemt
cd telemt
# Starting Release Build
cargo build --release
# Move to /bin
mv ./target/release/telemt /bin
# Make executable
chmod +x /bin/telemt
# Lets go!
telemt config.toml
```
## How to use?
### Telemt via Systemd
**0. Check port and generate secrets**
The port you have selected for use should be MISSING from the list, when:
```bash
netstat -lnp
```
Generate 16 bytes/32 characters HEX with OpenSSL or another way:
```bash
openssl rand -hex 16
```
**1. Place your config to /etc/telemt.toml**
Open nano
```bash
nano /etc/telemt.toml
```
```bash
port = 443 # Listening port
[users]
hello = "00000000000000000000000000000000" # Replace the secret with one generated before
[modes]
classic = false # Plain obfuscated mode
secure = false # dd-prefix mode
tls = true # Fake TLS - ee-prefix
tls_domain = "petrovich.ru" # Domain for ee-secret and masking
mask = true # Enable masking of bad traffic
mask_host = "petrovich.ru" # Optional override for mask destination
mask_port = 443 # Port for masking
prefer_ipv6 = false # Try IPv6 DCs first if true
fast_mode = true # Use "fast" obfuscation variant
client_keepalive = 600 # Seconds
client_ack_timeout = 300 # Seconds
```
then Ctrl+X -> Y -> Enter to save
**2. Create service on /etc/systemd/system/telemt.service**
Open nano
```bash
nano /etc/systemd/system/telemt.service
```
paste this Systemd Module
```bash
[Unit]
Description=Telemt
After=network.target
[Service]
Type=simple
WorkingDirectory=/bin
ExecStart=/bin/telemt /etc/telemt.toml
Restart=on-failure
[Install]
WantedBy=multi-user.target
```
then Ctrl+X -> Y -> Enter to save
**3.** In Shell type `systemctl start telemt` - it must start with zero exit-code
**4.** In Shell type `systemctl status telemt` - there you can reach info about current MTProxy status
**5.** In Shell type `systemctl enable telemt` - then telemt will start with system startup, after the network is up
## FAQ
### Telegram Calls via MTProxy
- Telegram architecture does **NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
### How does DPI see MTProxy TLS?
- DPI sees MTProxy in Fake TLS (ee) mode as TLS 1.3
- the SNI you specify sends both the client and the server;
- ALPN is similar to HTTP 1.1/2;
- high entropy, which is normal for AES-encrypted traffic;
### Whitelist on IP
- MTProxy cannot work when there is:
- no IP connectivity to the target host
- OR all TCP traffic is blocked
- OR all TLS traffic is blocked,
- like most protocols on the Internet;
- this situation is observed in China behind the Great Chinese Firewall and in Russia on mobile networks
## Why Rust?
- Long-running reliability and idempotent behavior
- Rusts deterministic resource management - RAII
- No garbage collector
- Memory safety and reduced attack surface
- Tokio's asynchronous architecture
## Roadmap
- Public IP in links
- Config Reload-on-fly
- Bind to device or IP for outbound/inbound connections
- Adtag Support per SNI / Secret
- Fail-fast on start + Fail-soft on runtime (only WARN/ERROR)
- Zero-copy, minimal allocs on hotpath
- DC Healthchecks + global fallback
- No global mutable state
- Client isolation + Fair Bandwidth
- Backpressure-aware IO
- "Secret Policy" - SNI / Secret Routing :D
- Multi-upstream Balancer and Failover
- Strict FSM per handshake
- Session-based Antireplay with Sliding window, non-broking reconnects
- Web Control: statistic, state of health, latency, client experience...

View File

@@ -1,21 +1,24 @@
//! AES
//! AES encryption implementations
//!
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>;
type Aes256CbcEnc = CbcEncryptor<Aes256>;
type Aes256CbcDec = CbcDecryptor<Aes256>;
// ============= AES-256-CTR =============
/// AES-256-CTR encryptor/decryptor
///
/// CTR mode is symmetric - encryption and decryption are the same operation.
pub struct AesCtr {
cipher: Aes256Ctr,
}
impl AesCtr {
/// Create new AES-CTR cipher with key and IV
pub fn new(key: &[u8; 32], iv: u128) -> Self {
let iv_bytes = iv.to_be_bytes();
Self {
@@ -23,6 +26,7 @@ impl AesCtr {
}
}
/// Create from key and IV slices
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
@@ -54,17 +58,28 @@ impl AesCtr {
}
}
/// AES-256-CBC Ciphermagic
// ============= AES-256-CBC =============
/// AES-256-CBC cipher with proper chaining
///
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
/// are different operations. This implementation handles CBC chaining
/// correctly across multiple blocks.
pub struct AesCbc {
key: [u8; 32],
iv: [u8; 16],
}
impl AesCbc {
/// AES block size
const BLOCK_SIZE: usize = 16;
/// Create new AES-CBC cipher with key and IV
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
Self { key, iv }
}
/// Create from slices
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
@@ -79,32 +94,36 @@ impl AesCbc {
})
}
/// Encrypt data using CBC mode
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % 16 != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(Vec::new());
}
let mut buffer = data.to_vec();
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
for chunk in buffer.chunks_mut(16) {
encryptor.encrypt_block_mut(chunk.into());
}
Ok(buffer)
/// Encrypt a single block using raw AES (no chaining)
fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
use aes::cipher::BlockEncrypt;
let mut output = *block;
key_schedule.encrypt_block((&mut output).into());
output
}
/// Decrypt data using CBC mode
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % 16 != 0 {
/// Decrypt a single block using raw AES (no chaining)
fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
use aes::cipher::BlockDecrypt;
let mut output = *block;
key_schedule.decrypt_block((&mut output).into());
output
}
/// XOR two 16-byte blocks
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
let mut result = [0u8; 16];
for i in 0..16 {
result[i] = a[i] ^ b[i];
}
result
}
/// Encrypt data using CBC mode with proper chaining
///
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
@@ -114,20 +133,73 @@ impl AesCbc {
return Ok(Vec::new());
}
let mut buffer = data.to_vec();
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv;
for chunk in buffer.chunks_mut(16) {
decryptor.decrypt_block_mut(chunk.into());
for chunk in data.chunks(Self::BLOCK_SIZE) {
let plaintext: [u8; 16] = chunk.try_into().unwrap();
// XOR plaintext with previous ciphertext (or IV for first block)
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
// Encrypt the XORed block
let ciphertext = self.encrypt_block(&xored, &key_schedule);
// Save for next iteration
prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&ciphertext);
}
Ok(buffer)
Ok(result)
}
/// Decrypt data using CBC mode with proper chaining
///
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(Vec::new());
}
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv;
for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
// Decrypt the block
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
// XOR with previous ciphertext (or IV for first block)
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
// Save current ciphertext for next iteration
prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&plaintext);
}
Ok(result)
}
/// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
@@ -137,10 +209,25 @@ impl AesCbc {
return Ok(());
}
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
for chunk in data.chunks_mut(16) {
encryptor.encrypt_block_mut(chunk.into());
let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
// Encrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule);
// Save for next iteration
prev_ciphertext = *block_array;
}
Ok(())
@@ -148,7 +235,7 @@ impl AesCbc {
/// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
@@ -158,16 +245,38 @@ impl AesCbc {
return Ok(());
}
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
for chunk in data.chunks_mut(16) {
decryptor.decrypt_block_mut(chunk.into());
// For in-place decryption, we need to save ciphertext blocks
// before we overwrite them
let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
// Save current ciphertext before modifying
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
// Decrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule);
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
// Save for next iteration
prev_ciphertext = current_ciphertext;
}
Ok(())
}
}
// ============= Encryption Traits =============
/// Trait for unified encryption interface
pub trait Encryptor: Send + Sync {
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
@@ -209,6 +318,8 @@ impl Decryptor for PassthroughEncryptor {
mod tests {
use super::*;
// ============= AES-CTR Tests =============
#[test]
fn test_aes_ctr_roundtrip() {
let key = [0u8; 32];
@@ -225,13 +336,35 @@ mod tests {
assert_eq!(original.as_slice(), decrypted.as_slice());
}
#[test]
fn test_aes_ctr_in_place() {
let key = [0x42u8; 32];
let iv = 999u128;
let original = b"Test data for in-place encryption";
let mut data = original.to_vec();
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
// Encrypted should be different
assert_ne!(&data[..], original);
// Decrypt with fresh cipher
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
assert_eq!(&data[..], original);
}
// ============= AES-CBC Tests =============
#[test]
fn test_aes_cbc_roundtrip() {
let key = [0u8; 32];
let iv = [0u8; 16];
// Must be aligned to 16 bytes
let original = [0u8; 32];
let original = [0u8; 32]; // 2 blocks
let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap();
@@ -242,47 +375,59 @@ mod tests {
#[test]
fn test_aes_cbc_chaining_works() {
// This is the key test - verify CBC chaining is correct
let key = [0x42u8; 32];
let iv = [0x00u8; 16];
let plaintext = [0xAA_u8; 32];
// Two IDENTICAL plaintext blocks
let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// CBC Corrections
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32];
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
assert_ne!(
block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \
This indicates ECB mode, not CBC!"
);
}
#[test]
fn test_aes_cbc_known_vector() {
// Test with known NIST test vector
// AES-256-CBC with zero key and zero IV
let key = [0u8; 32];
let iv = [0u8; 16];
// 3 Datablocks
let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 2
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 3 - different
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
];
let plaintext = [0u8; 16];
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt + Verify
// Decrypt and verify roundtrip
let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Verify Ciphertexts Block 1 != Block 2
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
// Ciphertext should not be all zeros
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
}
#[test]
fn test_aes_cbc_multi_block() {
let key = [0x12u8; 32];
let iv = [0x34u8; 16];
// 5 blocks = 80 bytes
let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext, decrypted);
}
#[test]
@@ -291,7 +436,7 @@ mod tests {
let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks
let mut buffer = original.clone();
let mut buffer = original;
let cipher = AesCbc::new(key, iv);
@@ -317,35 +462,85 @@ mod tests {
fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes
// 15 bytes - not aligned to block size
let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err());
// 17 bytes
// 17 bytes - not aligned
let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err());
}
#[test]
fn test_aes_cbc_avalanche_effect() {
// Cipherplane
// Changing one bit in plaintext should change entire ciphertext block
// and all subsequent blocks (due to chaining)
let key = [0xAB; 32];
let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Один бит отличается
plaintext2[0] = 0x01; // Single bit difference in first block
let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First Blocks Diff
// First blocks should be different
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second Blocks Diff
// Second blocks should ALSO be different (chaining effect)
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
}
#[test]
fn test_aes_cbc_iv_matters() {
// Same plaintext with different IVs should produce different ciphertext
let key = [0x55; 32];
let plaintext = [0x77u8; 16];
let cipher1 = AesCbc::new(key, [0u8; 16]);
let cipher2 = AesCbc::new(key, [1u8; 16]);
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
assert_ne!(ciphertext1, ciphertext2);
}
#[test]
fn test_aes_cbc_deterministic() {
// Same key, IV, plaintext should always produce same ciphertext
let key = [0x99; 32];
let iv = [0x88; 16];
let plaintext = [0x77u8; 32];
let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
assert_eq!(ciphertext1, ciphertext2);
}
// ============= Error Handling Tests =============
#[test]
fn test_invalid_key_length() {
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err());
}
#[test]
fn test_invalid_iv_length() {
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err());
}
}

View File

@@ -1,8 +1,177 @@
//! Error Types
use std::fmt;
use std::net::SocketAddr;
use thiserror::Error;
// ============= Stream Errors =============
/// Errors specific to stream I/O operations
#[derive(Debug)]
pub enum StreamError {
/// Partial read: got fewer bytes than expected
PartialRead {
expected: usize,
got: usize,
},
/// Partial write: wrote fewer bytes than expected
PartialWrite {
expected: usize,
written: usize,
},
/// Stream is in poisoned state and cannot be used
Poisoned {
reason: String,
},
/// Buffer overflow: attempted to buffer more than allowed
BufferOverflow {
limit: usize,
attempted: usize,
},
/// Invalid frame format
InvalidFrame {
details: String,
},
/// Unexpected end of stream
UnexpectedEof,
/// Underlying I/O error
Io(std::io::Error),
}
impl fmt::Display for StreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::PartialRead { expected, got } => {
write!(f, "partial read: expected {} bytes, got {}", expected, got)
}
Self::PartialWrite { expected, written } => {
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
}
Self::Poisoned { reason } => {
write!(f, "stream poisoned: {}", reason)
}
Self::BufferOverflow { limit, attempted } => {
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
}
Self::InvalidFrame { details } => {
write!(f, "invalid frame: {}", details)
}
Self::UnexpectedEof => {
write!(f, "unexpected end of stream")
}
Self::Io(e) => {
write!(f, "I/O error: {}", e)
}
}
}
}
impl std::error::Error for StreamError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
_ => None,
}
}
}
impl From<std::io::Error> for StreamError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
impl From<StreamError> for std::io::Error {
fn from(err: StreamError) -> Self {
match err {
StreamError::Io(e) => e,
StreamError::UnexpectedEof => {
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
}
StreamError::Poisoned { .. } => {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
StreamError::BufferOverflow { .. } => {
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
}
StreamError::InvalidFrame { .. } => {
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
}
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
}
}
}
// ============= Recoverable Trait =============
/// Trait for errors that may be recoverable
pub trait Recoverable {
/// Check if error is recoverable (can retry operation)
fn is_recoverable(&self) -> bool;
/// Check if connection can continue after this error
fn can_continue(&self) -> bool;
}
impl Recoverable for StreamError {
fn is_recoverable(&self) -> bool {
match self {
// Partial operations can be retried
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
// I/O errors depend on kind
Self::Io(e) => matches!(
e.kind(),
std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
),
// These are not recoverable
Self::Poisoned { .. }
| Self::BufferOverflow { .. }
| Self::InvalidFrame { .. }
| Self::UnexpectedEof => false,
}
}
fn can_continue(&self) -> bool {
match self {
// Poisoned stream cannot be used
Self::Poisoned { .. } => false,
// EOF means stream is done
Self::UnexpectedEof => false,
// Buffer overflow is fatal
Self::BufferOverflow { .. } => false,
// Others might allow continuation
_ => true,
}
}
}
impl Recoverable for std::io::Error {
fn is_recoverable(&self) -> bool {
matches!(
self.kind(),
std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
)
}
fn can_continue(&self) -> bool {
!matches!(
self.kind(),
std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::NotConnected
)
}
}
// ============= Main Proxy Errors =============
#[derive(Error, Debug)]
pub enum ProxyError {
// ============= Crypto Errors =============
@@ -13,6 +182,11 @@ pub enum ProxyError {
#[error("Invalid key length: expected {expected}, got {got}")]
InvalidKeyLength { expected: usize, got: usize },
// ============= Stream Errors =============
#[error("Stream error: {0}")]
Stream(#[from] StreamError),
// ============= Protocol Errors =============
#[error("Invalid handshake: {0}")]
@@ -39,6 +213,12 @@ pub enum ProxyError {
#[error("Sequence number mismatch: expected={expected}, got={got}")]
SeqNoMismatch { expected: i32, got: i32 },
#[error("TLS handshake failed: {reason}")]
TlsHandshakeFailed { reason: String },
#[error("Telegram handshake timeout")]
TgHandshakeTimeout,
// ============= Network Errors =============
#[error("Connection timeout to {addr}")]
@@ -77,15 +257,41 @@ pub enum ProxyError {
#[error("Unknown user")]
UnknownUser,
#[error("Rate limited")]
RateLimited,
// ============= General Errors =============
#[error("Internal error: {0}")]
Internal(String),
}
impl Recoverable for ProxyError {
fn is_recoverable(&self) -> bool {
match self {
Self::Stream(e) => e.is_recoverable(),
Self::Io(e) => e.is_recoverable(),
Self::ConnectionTimeout { .. } => true,
Self::RateLimited => true,
_ => false,
}
}
fn can_continue(&self) -> bool {
match self {
Self::Stream(e) => e.can_continue(),
Self::Io(e) => e.can_continue(),
_ => false,
}
}
}
/// Convenient Result type alias
pub type Result<T> = std::result::Result<T, ProxyError>;
/// Result type for stream operations
pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Result with optional bad client handling
#[derive(Debug)]
pub enum HandshakeResult<T> {
@@ -125,6 +331,14 @@ impl<T> HandshakeResult<T> {
HandshakeResult::Error(e) => HandshakeResult::Error(e),
}
}
/// Convert success to Option
pub fn ok(self) -> Option<T> {
match self {
HandshakeResult::Success(v) => Some(v),
_ => None,
}
}
}
impl<T> From<ProxyError> for HandshakeResult<T> {
@@ -139,10 +353,48 @@ impl<T> From<std::io::Error> for HandshakeResult<T> {
}
}
impl<T> From<StreamError> for HandshakeResult<T> {
fn from(err: StreamError) -> Self {
HandshakeResult::Error(ProxyError::Stream(err))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stream_error_display() {
let err = StreamError::PartialRead { expected: 100, got: 50 };
assert!(err.to_string().contains("100"));
assert!(err.to_string().contains("50"));
let err = StreamError::Poisoned { reason: "test".into() };
assert!(err.to_string().contains("test"));
}
#[test]
fn test_stream_error_recoverable() {
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
assert!(!StreamError::UnexpectedEof.is_recoverable());
}
#[test]
fn test_stream_error_can_continue() {
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
assert!(!StreamError::UnexpectedEof.can_continue());
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
}
#[test]
fn test_stream_error_to_io_error() {
let stream_err = StreamError::UnexpectedEof;
let io_err: std::io::Error = stream_err.into();
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
}
#[test]
fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
@@ -165,6 +417,15 @@ mod tests {
}
}
#[test]
fn test_proxy_error_recoverable() {
let err = ProxyError::RateLimited;
assert!(err.is_recoverable());
let err = ProxyError::InvalidHandshake("bad".into());
assert!(!err.is_recoverable());
}
#[test]
fn test_error_display() {
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };

View File

@@ -1,14 +1,22 @@
//! Fake TLS 1.3 Handshake
//!
//! This module handles the fake TLS 1.3 handshake used by MTProto proxy
//! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
use crate::error::{ProxyError, Result};
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
// ============= Public Constants =============
/// TLS handshake digest length
pub const TLS_DIGEST_LEN: usize = 32;
/// Position of digest in TLS ClientHello
pub const TLS_DIGEST_POS: usize = 11;
/// Length to store for replay protection (first 16 bytes of digest)
pub const TLS_DIGEST_HALF_LEN: usize = 16;
@@ -16,6 +24,26 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
// ============= Private Constants =============
/// TLS Extension types
mod extension_type {
pub const KEY_SHARE: u16 = 0x0033;
pub const SUPPORTED_VERSIONS: u16 = 0x002b;
}
/// TLS Cipher Suites
mod cipher_suite {
pub const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
}
/// TLS Named Curves
mod named_curve {
pub const X25519: u16 = 0x001d;
}
// ============= TLS Validation Result =============
/// Result of validating TLS handshake
#[derive(Debug)]
pub struct TlsValidation {
@@ -29,7 +57,185 @@ pub struct TlsValidation {
pub timestamp: u32,
}
// ============= TLS Extension Builder =============
/// Builder for TLS extensions with correct length calculation
struct TlsExtensionBuilder {
extensions: Vec<u8>,
}
impl TlsExtensionBuilder {
fn new() -> Self {
Self {
extensions: Vec::with_capacity(128),
}
}
/// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
// Extension type: key_share (0x0033)
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
// Key data
self.extensions.extend_from_slice(public_key);
self
}
/// Add Supported Versions extension
fn add_supported_versions(&mut self, version: u16) -> &mut Self {
// Extension type: supported_versions (0x002b)
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
// Extension data: length (2) + version (2)
self.extensions.extend_from_slice(&(2u16).to_be_bytes());
// Selected version
self.extensions.extend_from_slice(&version.to_be_bytes());
self
}
/// Build final extensions with length prefix
fn build(self) -> Vec<u8> {
let mut result = Vec::with_capacity(2 + self.extensions.len());
// Extensions length (2 bytes)
let len = self.extensions.len() as u16;
result.extend_from_slice(&len.to_be_bytes());
// Extensions data
result.extend_from_slice(&self.extensions);
result
}
/// Get current extensions without length prefix (for calculation)
#[allow(dead_code)]
fn as_bytes(&self) -> &[u8] {
&self.extensions
}
}
// ============= ServerHello Builder =============
/// Builder for TLS ServerHello with correct structure
struct ServerHelloBuilder {
/// Random bytes (32 bytes, will contain digest)
random: [u8; 32],
/// Session ID (echoed from ClientHello)
session_id: Vec<u8>,
/// Cipher suite
cipher_suite: [u8; 2],
/// Compression method
compression: u8,
/// Extensions
extensions: TlsExtensionBuilder,
}
impl ServerHelloBuilder {
fn new(session_id: Vec<u8>) -> Self {
Self {
random: [0u8; 32],
session_id,
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0x00,
extensions: TlsExtensionBuilder::new(),
}
}
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
self.extensions.add_key_share(key);
self
}
fn with_tls13_version(mut self) -> Self {
// TLS 1.3 = 0x0304
self.extensions.add_supported_versions(0x0304);
self
}
/// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> {
let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16;
// Calculate total length
let body_len = 2 + // version
32 + // random
1 + self.session_id.len() + // session_id length + data
2 + // cipher suite
1 + // compression
2 + extensions.len(); // extensions length + data
let mut message = Vec::with_capacity(4 + body_len);
// Handshake header
message.push(0x02); // ServerHello message type
// 3-byte length
let len_bytes = (body_len as u32).to_be_bytes();
message.extend_from_slice(&len_bytes[1..4]);
// Server version (TLS 1.2 in header, actual version in extension)
message.extend_from_slice(&TLS_VERSION);
// Random (32 bytes) - placeholder, will be replaced with digest
message.extend_from_slice(&self.random);
// Session ID
message.push(self.session_id.len() as u8);
message.extend_from_slice(&self.session_id);
// Cipher suite
message.extend_from_slice(&self.cipher_suite);
// Compression method
message.push(self.compression);
// Extensions length
message.extend_from_slice(&extensions_len.to_be_bytes());
// Extensions data
message.extend_from_slice(&extensions);
message
}
/// Build complete ServerHello TLS record
fn build_record(&self) -> Vec<u8> {
let message = self.build_message();
let mut record = Vec::with_capacity(5 + message.len());
// TLS record header
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(message.len() as u16).to_be_bytes());
// Message
record.extend_from_slice(&message);
record
}
}
// ============= Public Functions =============
/// Validate TLS ClientHello against user secrets
///
/// Returns validation result if a matching user is found.
pub fn validate_tls_handshake(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@@ -86,7 +292,8 @@ pub fn validate_tls_handshake(
// Check time skew
if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time)
let is_boot_time = timestamp < 60 * 60 * 24 * 1000;
// This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
continue;
@@ -105,15 +312,22 @@ pub fn validate_tls_handshake(
}
/// Generate a fake X25519 public key for TLS
/// This generates a value that looks like a valid X25519 key
///
/// This generates random bytes that look like a valid X25519 public key.
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
pub fn gen_fake_x25519_key() -> [u8; 32] {
// For simplicity, just generate random 32 bytes
// In real X25519, this would be a point on the curve
let bytes = SECURE_RANDOM.bytes(32);
bytes.try_into().unwrap()
}
/// Build TLS ServerHello response
///
/// This builds a complete TLS 1.3-like response including:
/// - ServerHello record with extensions
/// - Change Cipher Spec record
/// - Fake encrypted certificate (Application Data record)
///
/// The response includes an HMAC digest that the client can verify.
pub fn build_server_hello(
secret: &[u8],
client_digest: &[u8; TLS_DIGEST_LEN],
@@ -122,62 +336,48 @@ pub fn build_server_hello(
) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key();
// TLS extensions
let mut extensions = Vec::new();
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve
extensions.extend_from_slice(&x25519_key);
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
// Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_x25519_key(&x25519_key)
.with_tls13_version()
.build_record();
// ServerHello body
let mut srv_hello = Vec::new();
srv_hello.extend_from_slice(&TLS_VERSION);
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest
srv_hello.push(session_id.len() as u8);
srv_hello.extend_from_slice(session_id);
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
srv_hello.push(0x00); // No compression
srv_hello.extend_from_slice(&extensions);
// Build complete packet
let mut hello_pkt = Vec::new();
// ServerHello record
hello_pkt.push(TLS_RECORD_HANDSHAKE);
hello_pkt.extend_from_slice(&TLS_VERSION);
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes());
hello_pkt.push(0x02); // ServerHello message type
let len_bytes = (srv_hello.len() as u32).to_be_bytes();
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length
hello_pkt.extend_from_slice(&srv_hello);
// Change Cipher Spec record
hello_pkt.extend_from_slice(&[
// Build Change Cipher Spec record
let change_cipher_spec = [
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x01, 0x01
]);
0x00, 0x01, // length = 1
0x01, // CCS byte
];
// Application Data record (fake certificate)
// Build fake certificate (Application Data record)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
hello_pkt.push(TLS_RECORD_APPLICATION);
hello_pkt.extend_from_slice(&TLS_VERSION);
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
hello_pkt.extend_from_slice(&fake_cert);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
app_data_record.extend_from_slice(&(fake_cert_len as u16).to_be_bytes());
app_data_record.extend_from_slice(&fake_cert);
// Combine all records
let mut response = Vec::with_capacity(
server_hello.len() + change_cipher_spec.len() + app_data_record.len()
);
response.extend_from_slice(&server_hello);
response.extend_from_slice(&change_cipher_spec);
response.extend_from_slice(&app_data_record);
// Compute HMAC for the response
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len());
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&hello_pkt);
hmac_input.extend_from_slice(&response);
let response_digest = sha256_hmac(secret, &hmac_input);
// Insert computed digest
// Position: after record header (5) + message type/length (4) + version (2) = 11
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
// Insert computed digest into ServerHello
// Position: record header (5) + message type (1) + length (3) + version (2) = 11
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&response_digest);
hello_pkt
response
}
/// Check if bytes look like a TLS ClientHello
@@ -186,7 +386,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
return false;
}
// TLS record header: 0x16 0x03 0x01
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
&& first_bytes[2] == 0x01
@@ -206,6 +406,61 @@ pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
Some((record_type, length))
}
/// Validate a ServerHello response structure
///
/// This is useful for testing that our ServerHello is well-formed.
#[cfg(test)]
fn validate_server_hello_structure(data: &[u8]) -> Result<()> {
if data.len() < 5 {
return Err(ProxyError::InvalidTlsRecord {
record_type: 0,
version: [0, 0],
});
}
// Check record header
if data[0] != TLS_RECORD_HANDSHAKE {
return Err(ProxyError::InvalidTlsRecord {
record_type: data[0],
version: [data[1], data[2]],
});
}
// Check version
if data[1..3] != TLS_VERSION {
return Err(ProxyError::InvalidTlsRecord {
record_type: data[0],
version: [data[1], data[2]],
});
}
// Check record length
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
if data.len() < 5 + record_len {
return Err(ProxyError::InvalidHandshake(
format!("ServerHello record truncated: expected {}, got {}",
5 + record_len, data.len())
));
}
// Check message type
if data[5] != 0x02 {
return Err(ProxyError::InvalidHandshake(
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
));
}
// Parse message length
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
if msg_len + 4 != record_len {
return Err(ProxyError::InvalidHandshake(
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
@@ -241,4 +496,145 @@ mod tests {
assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random
}
#[test]
fn test_tls_extension_builder() {
let key = [0x42u8; 32];
let mut builder = TlsExtensionBuilder::new();
builder.add_key_share(&key);
builder.add_supported_versions(0x0304);
let result = builder.build();
// Check length prefix
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
assert_eq!(len, result.len() - 2);
// Check key_share extension is present
assert!(result.len() > 40); // At least key share
}
#[test]
fn test_server_hello_builder() {
let session_id = vec![0x01, 0x02, 0x03, 0x04];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id.clone())
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Validate structure
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
// Check record type
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
// Check version
assert_eq!(&record[1..3], &TLS_VERSION);
// Check message type (ServerHello = 0x02)
assert_eq!(record[5], 0x02);
}
#[test]
fn test_build_server_hello_structure() {
let secret = b"test secret";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
// Should have at least 3 records
assert!(response.len() > 100);
// First record should be ServerHello
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
// Validate ServerHello structure
validate_server_hello_structure(&response).expect("Invalid ServerHello");
// Find Change Cipher Spec
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = server_hello_len;
assert!(response.len() > ccs_start + 6);
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
// Find Application Data
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
}
#[test]
fn test_build_server_hello_digest() {
let secret = b"test secret key here";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
// Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert!(!digest1.iter().all(|&b| b == 0));
// Different calls should have different digests (due to random cert)
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert_ne!(digest1, digest2);
}
#[test]
fn test_server_hello_extensions_length() {
let session_id = vec![0x01; 32];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id)
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Parse to find extensions
let msg_start = 5; // After record header
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
// Skip to session ID
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
let session_id_len = record[session_id_pos] as usize;
// Skip to extensions
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
// Verify extensions length matches actual data
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
assert_eq!(ext_len, extensions_data.len(),
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
}
#[test]
fn test_validate_tls_handshake_format() {
// Build a minimal ClientHello-like structure
let mut handshake = vec![0u8; 100];
// Put a valid-looking digest at position 11
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&[0x42; 32]);
// Session ID length
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
// This won't validate (wrong HMAC) but shouldn't panic
let secrets = vec![("test".to_string(), b"secret".to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true);
// Should return None (no match) but not panic
assert!(result.is_none());
}
}

450
src/stream/buffer_pool.rs Normal file
View File

@@ -0,0 +1,450 @@
//! Reusable buffer pool to avoid allocations in hot paths
//!
//! This module provides a thread-safe pool of BytesMut buffers
//! that can be reused across connections to reduce allocation pressure.
use bytes::BytesMut;
use crossbeam_queue::ArrayQueue;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
// ============= Configuration =============
/// Default buffer size (64KB - good for MTProto)
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
/// Default maximum number of pooled buffers
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
// ============= Buffer Pool =============
/// Thread-safe pool of reusable buffers
pub struct BufferPool {
/// Queue of available buffers
buffers: ArrayQueue<BytesMut>,
/// Size of each buffer
buffer_size: usize,
/// Maximum number of buffers to pool
max_buffers: usize,
/// Total allocated buffers (including in-use)
allocated: AtomicUsize,
/// Number of times we had to create a new buffer
misses: AtomicUsize,
/// Number of successful reuses
hits: AtomicUsize,
}
impl BufferPool {
/// Create a new buffer pool with default settings
pub fn new() -> Self {
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
}
/// Create a buffer pool with custom configuration
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
Self {
buffers: ArrayQueue::new(max_buffers),
buffer_size,
max_buffers,
allocated: AtomicUsize::new(0),
misses: AtomicUsize::new(0),
hits: AtomicUsize::new(0),
}
}
/// Get a buffer from the pool, or create a new one if empty
pub fn get(self: &Arc<Self>) -> PooledBuffer {
match self.buffers.pop() {
Some(mut buffer) => {
self.hits.fetch_add(1, Ordering::Relaxed);
buffer.clear();
PooledBuffer {
buffer: Some(buffer),
pool: Arc::clone(self),
}
}
None => {
self.misses.fetch_add(1, Ordering::Relaxed);
self.allocated.fetch_add(1, Ordering::Relaxed);
PooledBuffer {
buffer: Some(BytesMut::with_capacity(self.buffer_size)),
pool: Arc::clone(self),
}
}
}
}
/// Try to get a buffer, returns None if pool is empty
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
self.buffers.pop().map(|mut buffer| {
self.hits.fetch_add(1, Ordering::Relaxed);
buffer.clear();
PooledBuffer {
buffer: Some(buffer),
pool: Arc::clone(self),
}
})
}
/// Return a buffer to the pool
fn return_buffer(&self, mut buffer: BytesMut) {
// Clear the buffer but keep capacity
buffer.clear();
// Only return if we haven't exceeded max and buffer is right size
if buffer.capacity() >= self.buffer_size {
// Try to push to pool, if full just drop
let _ = self.buffers.push(buffer);
}
// If buffer was dropped (pool full), decrement allocated
// Actually we don't decrement here because the buffer might have been
// grown beyond our size - we just let it go
}
/// Get pool statistics
pub fn stats(&self) -> PoolStats {
PoolStats {
pooled: self.buffers.len(),
allocated: self.allocated.load(Ordering::Relaxed),
max_buffers: self.max_buffers,
buffer_size: self.buffer_size,
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
}
}
/// Get buffer size
pub fn buffer_size(&self) -> usize {
self.buffer_size
}
/// Preallocate buffers to fill the pool
pub fn preallocate(&self, count: usize) {
let to_alloc = count.min(self.max_buffers);
for _ in 0..to_alloc {
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
break;
}
self.allocated.fetch_add(1, Ordering::Relaxed);
}
}
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
// ============= Pool Statistics =============
/// Statistics about buffer pool usage
#[derive(Debug, Clone)]
pub struct PoolStats {
/// Current number of buffers in pool
pub pooled: usize,
/// Total buffers allocated (in-use + pooled)
pub allocated: usize,
/// Maximum buffers allowed
pub max_buffers: usize,
/// Size of each buffer
pub buffer_size: usize,
/// Number of cache hits (reused buffer)
pub hits: usize,
/// Number of cache misses (new allocation)
pub misses: usize,
}
impl PoolStats {
/// Get hit rate as percentage
pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
(self.hits as f64 / total as f64) * 100.0
}
}
}
// ============= Pooled Buffer =============
/// A buffer that automatically returns to the pool when dropped
pub struct PooledBuffer {
buffer: Option<BytesMut>,
pool: Arc<BufferPool>,
}
impl PooledBuffer {
/// Take the inner buffer, preventing return to pool
pub fn take(mut self) -> BytesMut {
self.buffer.take().unwrap()
}
/// Get the capacity of the buffer
pub fn capacity(&self) -> usize {
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
}
/// Get the length of data in buffer
pub fn len(&self) -> usize {
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
}
/// Clear the buffer
pub fn clear(&mut self) {
if let Some(ref mut b) = self.buffer {
b.clear();
}
}
}
impl Deref for PooledBuffer {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.as_ref().expect("buffer taken")
}
}
impl DerefMut for PooledBuffer {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer.as_mut().expect("buffer taken")
}
}
impl Drop for PooledBuffer {
fn drop(&mut self) {
if let Some(buffer) = self.buffer.take() {
self.pool.return_buffer(buffer);
}
}
}
impl AsRef<[u8]> for PooledBuffer {
fn as_ref(&self) -> &[u8] {
self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[])
}
}
impl AsMut<[u8]> for PooledBuffer {
fn as_mut(&mut self) -> &mut [u8] {
self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut [])
}
}
// ============= Scoped Buffer =============
/// A buffer that can be used for a scoped operation
/// Useful for ensuring buffer is returned even on early return
pub struct ScopedBuffer<'a> {
buffer: &'a mut PooledBuffer,
}
impl<'a> ScopedBuffer<'a> {
/// Create a new scoped buffer
pub fn new(buffer: &'a mut PooledBuffer) -> Self {
buffer.clear();
Self { buffer }
}
}
impl<'a> Deref for ScopedBuffer<'a> {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.deref()
}
}
impl<'a> DerefMut for ScopedBuffer<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer.deref_mut()
}
}
impl<'a> Drop for ScopedBuffer<'a> {
fn drop(&mut self) {
self.buffer.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pool_basic() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get a buffer
let mut buf1 = pool.get();
buf1.extend_from_slice(b"hello");
assert_eq!(&buf1[..], b"hello");
// Drop returns to pool
drop(buf1);
let stats = pool.stats();
assert_eq!(stats.pooled, 1);
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 1);
// Get again - should reuse
let buf2 = pool.get();
assert!(buf2.is_empty()); // Buffer was cleared
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
assert_eq!(stats.hits, 1);
}
#[test]
fn test_pool_multiple_buffers() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get multiple buffers
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.allocated, 3);
assert_eq!(stats.pooled, 0);
// Return all
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 3);
}
#[test]
fn test_pool_overflow() {
let pool = Arc::new(BufferPool::with_config(1024, 2));
// Get 3 buffers (more than max)
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
// Return all - only 2 should be pooled
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 2);
}
#[test]
fn test_pool_take() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
buf.extend_from_slice(b"data");
// Take ownership, buffer should not return to pool
let taken = buf.take();
assert_eq!(&taken[..], b"data");
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
}
#[test]
fn test_pool_preallocate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
pool.preallocate(5);
let stats = pool.stats();
assert_eq!(stats.pooled, 5);
assert_eq!(stats.allocated, 5);
}
#[test]
fn test_pool_try_get() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Pool is empty, try_get returns None
assert!(pool.try_get().is_none());
// Add a buffer to pool
pool.preallocate(1);
// Now try_get should succeed
assert!(pool.try_get().is_some());
assert!(pool.try_get().is_none());
}
#[test]
fn test_hit_rate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// First get is a miss
let buf1 = pool.get();
drop(buf1);
// Second get is a hit
let buf2 = pool.get();
drop(buf2);
// Third get is a hit
let _buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
}
#[test]
fn test_scoped_buffer() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
{
let mut scoped = ScopedBuffer::new(&mut buf);
scoped.extend_from_slice(b"scoped data");
assert_eq!(&scoped[..], b"scoped data");
}
// After scoped is dropped, buffer is cleared
assert!(buf.is_empty());
}
#[test]
fn test_concurrent_access() {
use std::thread;
let pool = Arc::new(BufferPool::with_config(1024, 100));
let mut handles = vec![];
for _ in 0..10 {
let pool_clone = Arc::clone(&pool);
handles.push(thread::spawn(move || {
for _ in 0..100 {
let mut buf = pool_clone.get();
buf.extend_from_slice(b"test");
// buf auto-returned on drop
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let stats = pool.stats();
// All buffers should be returned
assert!(stats.pooled > 0);
}
}

File diff suppressed because it is too large Load Diff

187
src/stream/frame.rs Normal file
View File

@@ -0,0 +1,187 @@
//! MTProto frame types and traits
//!
//! This module defines the common types and traits used by all
//! frame encoding/decoding implementations.
use bytes::{Bytes, BytesMut};
use std::io::Result;
use crate::protocol::constants::ProtoTag;
// ============= Frame Types =============
/// A decoded MTProto frame
#[derive(Debug, Clone)]
pub struct Frame {
/// Frame payload data
pub data: Bytes,
/// Frame metadata
pub meta: FrameMeta,
}
impl Frame {
/// Create a new frame with data and default metadata
pub fn new(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta::default(),
}
}
/// Create a new frame with data and metadata
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
Self { data, meta }
}
/// Create an empty frame
pub fn empty() -> Self {
Self::new(Bytes::new())
}
/// Check if frame is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get frame length
pub fn len(&self) -> usize {
self.data.len()
}
/// Create a QuickAck request frame
pub fn quickack(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta {
quickack: true,
..Default::default()
},
}
}
/// Create a simple ACK frame
pub fn simple_ack(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta {
simple_ack: true,
..Default::default()
},
}
}
}
/// Frame metadata
#[derive(Debug, Clone, Default)]
pub struct FrameMeta {
/// Quick ACK requested - client wants immediate acknowledgment
pub quickack: bool,
/// This is a simple ACK message (reversed data)
pub simple_ack: bool,
/// Original padding length (for secure mode)
pub padding_len: u8,
}
impl FrameMeta {
/// Create new empty metadata
pub fn new() -> Self {
Self::default()
}
/// Create with quickack flag
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
/// Create with simple_ack flag
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
}
/// Create with padding length
pub fn with_padding(mut self, len: u8) -> Self {
self.padding_len = len;
self
}
/// Check if any special flags are set
pub fn has_flags(&self) -> bool {
self.quickack || self.simple_ack
}
}
// ============= Codec Trait =============
/// Trait for frame codecs that can encode and decode frames
pub trait FrameCodec: Send + Sync {
/// Get the protocol tag for this codec
fn proto_tag(&self) -> ProtoTag;
/// Encode a frame into the destination buffer
///
/// Returns the number of bytes written.
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;
/// Try to decode a frame from the source buffer
///
/// Returns:
/// - `Ok(Some(frame))` if a complete frame was decoded
/// - `Ok(None)` if more data is needed
/// - `Err(e)` if an error occurred
///
/// On success, the consumed bytes are removed from `src`.
fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;
/// Get the minimum bytes needed to determine frame length
fn min_header_size(&self) -> usize;
/// Get the maximum allowed frame size
fn max_frame_size(&self) -> usize {
// Default: 16MB
16 * 1024 * 1024
}
}
// ============= Codec Factory =============
/// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
match proto_tag {
ProtoTag::Abridged => Box::new(super::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(super::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(super::frame_codec::SecureCodec::new()),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_creation() {
let frame = Frame::new(Bytes::from_static(b"test"));
assert_eq!(frame.len(), 4);
assert!(!frame.is_empty());
assert!(!frame.meta.quickack);
let frame = Frame::empty();
assert!(frame.is_empty());
let frame = Frame::quickack(Bytes::from_static(b"ack"));
assert!(frame.meta.quickack);
}
#[test]
fn test_frame_meta() {
let meta = FrameMeta::new()
.with_quickack()
.with_padding(3);
assert!(meta.quickack);
assert!(!meta.simple_ack);
assert_eq!(meta.padding_len, 3);
assert!(meta.has_flags());
}
}

621
src/stream/frame_codec.rs Normal file
View File

@@ -0,0 +1,621 @@
//! tokio-util codec integration for MTProto frames
//!
//! This module provides Encoder/Decoder implementations compatible
//! with tokio-util's Framed wrapper for easy async frame I/O.
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind};
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec =============
/// Unified frame codec that wraps all protocol variants
///
/// This codec implements tokio-util's Encoder and Decoder traits,
/// allowing it to be used with `Framed` for async frame I/O.
pub struct FrameCodec {
/// Protocol variant
proto_tag: ProtoTag,
/// Maximum allowed frame size
max_frame_size: usize,
}
impl FrameCodec {
/// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag) -> Self {
Self {
proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default
}
}
/// Set maximum frame size
pub fn with_max_frame_size(mut self, size: usize) -> Self {
self.max_frame_size = size;
self
}
/// Get protocol tag
pub fn proto_tag(&self) -> ProtoTag {
self.proto_tag
}
}
impl Decoder for FrameCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size),
ProtoTag::Secure => decode_secure(src, self.max_frame_size),
}
}
}
impl Encoder<Frame> for FrameCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst),
}
}
}
// ============= Abridged Protocol =============
fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.is_empty() {
return Ok(None);
}
let mut meta = FrameMeta::new();
let first_byte = src[0];
// Extract length and quickack flag
let mut len_words = (first_byte & 0x7f) as usize;
if first_byte >= 0x80 {
meta.quickack = true;
}
let header_len;
if len_words == 0x7f {
// Extended length (3 more bytes needed)
if src.len() < 4 {
return Ok(None);
}
len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize;
header_len = 4;
} else {
header_len = 1;
}
// Length is in 4-byte words
let byte_len = len_words.checked_mul(4).ok_or_else(|| {
Error::new(ErrorKind::InvalidData, "frame length overflow")
})?;
// Validate size
if byte_len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", byte_len, max_size)
));
}
let total_len = header_len + byte_len;
if src.len() < total_len {
// Reserve space for the rest of the frame
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(header_len);
let data = src.split_to(byte_len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Validate alignment
if data.len() % 4 != 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
));
}
// Simple ACK: send reversed data without header
if frame.meta.simple_ack {
dst.reserve(data.len());
for byte in data.iter().rev() {
dst.put_u8(*byte);
}
return Ok(());
}
let len_words = data.len() / 4;
if len_words < 0x7f {
// Short header
dst.reserve(1 + data.len());
let mut len_byte = len_words as u8;
if frame.meta.quickack {
len_byte |= 0x80;
}
dst.put_u8(len_byte);
} else if len_words < (1 << 24) {
// Extended header
dst.reserve(4 + data.len());
let mut first = 0x7fu8;
if frame.meta.quickack {
first |= 0x80;
}
dst.put_u8(first);
let len_bytes = (len_words as u32).to_le_bytes();
dst.extend_from_slice(&len_bytes[..3]);
} else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", data.len())
));
}
dst.extend_from_slice(data);
Ok(())
}
// ============= Intermediate Protocol =============
fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(4);
let data = src.split_to(len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
dst.reserve(4 + data.len());
let mut len = data.len() as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
Ok(())
}
// ============= Secure Intermediate Protocol =============
fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
// Calculate padding (indicated by length not divisible by 4)
let padding_len = len % 4;
let data_len = if padding_len != 0 {
len - padding_len
} else {
len
};
meta.padding_len = padding_len as u8;
// Extract data (excluding padding)
let _ = src.split_to(4);
let all_data = src.split_to(len);
// Copy only the data portion, excluding padding
let data = Bytes::copy_from_slice(&all_data[..data_len]);
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
use crate::crypto::random::SECURE_RANDOM;
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
// Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned
(SECURE_RANDOM.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
SECURE_RANDOM.range(4) as usize
};
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);
let mut len = total_len as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
if padding_len > 0 {
let padding = SECURE_RANDOM.bytes(padding_len);
dst.extend_from_slice(&padding);
}
Ok(())
}
// ============= Typed Codecs =============
/// Abridged protocol codec
pub struct AbridgedCodec {
max_frame_size: usize,
}
impl AbridgedCodec {
pub fn new() -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
}
}
}
impl Default for AbridgedCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for AbridgedCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_abridged(src, self.max_frame_size)
}
}
impl Encoder<Frame> for AbridgedCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_abridged(&frame, dst)
}
}
impl FrameCodecTrait for AbridgedCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Abridged
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_abridged(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_abridged(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
1
}
}
/// Intermediate protocol codec
pub struct IntermediateCodec {
max_frame_size: usize,
}
impl IntermediateCodec {
pub fn new() -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
}
}
}
impl Default for IntermediateCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for IntermediateCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_intermediate(src, self.max_frame_size)
}
}
impl Encoder<Frame> for IntermediateCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_intermediate(&frame, dst)
}
}
impl FrameCodecTrait for IntermediateCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Intermediate
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_intermediate(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_intermediate(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
}
/// Secure Intermediate protocol codec
pub struct SecureCodec {
max_frame_size: usize,
}
impl SecureCodec {
pub fn new() -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
}
}
}
impl Default for SecureCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for SecureCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_secure(src, self.max_frame_size)
}
}
impl Encoder<Frame> for SecureCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst)
}
}
impl FrameCodecTrait for SecureCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Secure
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_secure(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_secure(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
}
// ============= Tests =============
#[cfg(test)]
mod tests {
use super::*;
use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex;
use futures::{SinkExt, StreamExt};
#[tokio::test]
async fn test_framed_abridged() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
let mut reader = FramedRead::new(server, AbridgedCodec::new());
// Write a frame
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
writer.send(frame).await.unwrap();
// Read it back
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
}
#[tokio::test]
async fn test_framed_intermediate() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::new(Bytes::from_static(b"hello world"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], b"hello world");
}
#[tokio::test]
async fn test_framed_secure() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new());
let mut reader = FramedRead::new(server, SecureCodec::new());
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &original[..]);
}
#[tokio::test]
async fn test_unified_codec() {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
// Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), 8);
}
}
#[tokio::test]
async fn test_multiple_frames() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
// Send multiple frames
for i in 0..10 {
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
let frame = Frame::new(Bytes::from(data));
writer.send(frame).await.unwrap();
}
// Receive them
for i in 0..10 {
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), (i + 1) * 10);
}
}
#[tokio::test]
async fn test_quickack_flag() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert!(received.meta.quickack);
}
#[test]
fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
.with_max_frame_size(100);
// Create a "frame" that claims to be very large
let mut buf = BytesMut::new();
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
buf.extend_from_slice(&[0u8; 10]); // partial data
let result = codec.decode(&mut buf);
assert!(result.is_err());
}
}

View File

@@ -1,10 +1,43 @@
//! Stream wrappers for MTProto protocol layers
pub mod state;
pub mod buffer_pool;
pub mod traits;
pub mod crypto_stream;
pub mod tls_stream;
pub mod frame;
pub mod frame_codec;
// Legacy compatibility - will be removed later
pub mod frame_stream;
// Re-export state machine types
pub use state::{
StreamState, Transition, PollResult,
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
};
// Re-export buffer pool
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
// Re-export stream implementations
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
pub use frame_stream::*;
// Re-export frame types
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
// Re-export tokio-util compatible codecs
pub use frame_codec::{
FrameCodec,
AbridgedCodec, IntermediateCodec, SecureCodec,
};
// Legacy re-exports for compatibility
pub use frame_stream::{
AbridgedFrameReader, AbridgedFrameWriter,
IntermediateFrameReader, IntermediateFrameWriter,
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
MtprotoFrameReader, MtprotoFrameWriter,
FrameReaderKind, FrameWriterKind,
};

571
src/stream/state.rs Normal file
View File

@@ -0,0 +1,571 @@
//! State machine foundation types for async streams
//!
//! This module provides core types and traits for implementing
//! stateful async streams with proper partial read/write handling.
use bytes::{Bytes, BytesMut};
use std::io;
// ============= Core Traits =============
/// Trait for stream states
pub trait StreamState: Sized {
/// Check if this is a terminal state (no more transitions possible)
fn is_terminal(&self) -> bool;
/// Check if stream is in poisoned/error state
fn is_poisoned(&self) -> bool;
/// Get human-readable state name for debugging
fn state_name(&self) -> &'static str;
}
// ============= Transition Types =============
/// Result of a state transition
#[derive(Debug)]
pub enum Transition<S, O> {
/// Stay in the same state, no output
Same,
/// Transition to a new state, no output
Next(S),
/// Complete with output, typically transitions to Idle
Complete(O),
/// Yield output and transition to new state
Yield(O, S),
/// Error occurred, transition to error state
Error(io::Error),
}
impl<S, O> Transition<S, O> {
/// Check if transition produces output
pub fn has_output(&self) -> bool {
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
}
/// Map the output value
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
match self {
Transition::Same => Transition::Same,
Transition::Next(s) => Transition::Next(s),
Transition::Complete(o) => Transition::Complete(f(o)),
Transition::Yield(o, s) => Transition::Yield(f(o), s),
Transition::Error(e) => Transition::Error(e),
}
}
/// Map the state value
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
match self {
Transition::Same => Transition::Same,
Transition::Next(s) => Transition::Next(f(s)),
Transition::Complete(o) => Transition::Complete(o),
Transition::Yield(o, s) => Transition::Yield(o, f(s)),
Transition::Error(e) => Transition::Error(e),
}
}
}
// ============= Poll Result Types =============
/// Result of polling for more data
#[derive(Debug)]
pub enum PollResult<T> {
/// Data is ready
Ready(T),
/// Operation would block, need to poll again
Pending,
/// Need more input data (minimum bytes required)
NeedInput(usize),
/// End of stream reached
Eof,
/// Error occurred
Error(io::Error),
}
impl<T> PollResult<T> {
/// Check if result is ready
pub fn is_ready(&self) -> bool {
matches!(self, PollResult::Ready(_))
}
/// Check if result indicates EOF
pub fn is_eof(&self) -> bool {
matches!(self, PollResult::Eof)
}
/// Convert to Option, discarding non-ready states
pub fn ok(self) -> Option<T> {
match self {
PollResult::Ready(t) => Some(t),
_ => None,
}
}
/// Map the value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
match self {
PollResult::Ready(t) => PollResult::Ready(f(t)),
PollResult::Pending => PollResult::Pending,
PollResult::NeedInput(n) => PollResult::NeedInput(n),
PollResult::Eof => PollResult::Eof,
PollResult::Error(e) => PollResult::Error(e),
}
}
}
impl<T> From<io::Result<T>> for PollResult<T> {
fn from(result: io::Result<T>) -> Self {
match result {
Ok(t) => PollResult::Ready(t),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof,
Err(e) => PollResult::Error(e),
}
}
}
// ============= Buffer State =============
/// State for buffered reading operations
#[derive(Debug)]
pub struct ReadBuffer {
/// The buffer holding data
buffer: BytesMut,
/// Target number of bytes to read (if known)
target: Option<usize>,
}
impl ReadBuffer {
/// Create new empty read buffer
pub fn new() -> Self {
Self {
buffer: BytesMut::with_capacity(8192),
target: None,
}
}
/// Create with specific capacity
pub fn with_capacity(capacity: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(capacity),
target: None,
}
}
/// Create with target size
pub fn with_target(target: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(target),
target: Some(target),
}
}
/// Get current buffer length
pub fn len(&self) -> usize {
self.buffer.len()
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
/// Check if target is reached
pub fn is_complete(&self) -> bool {
match self.target {
Some(t) => self.buffer.len() >= t,
None => false,
}
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
match self.target {
Some(t) => t.saturating_sub(self.buffer.len()),
None => 0,
}
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) {
self.buffer.extend_from_slice(data);
}
/// Take all data from buffer
pub fn take(&mut self) -> Bytes {
self.target = None;
self.buffer.split().freeze()
}
/// Take exactly n bytes
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
if self.buffer.len() >= n {
Some(self.buffer.split_to(n).freeze())
} else {
None
}
}
/// Get a slice of the buffer
pub fn as_slice(&self) -> &[u8] {
&self.buffer
}
/// Get mutable access to underlying BytesMut
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
&mut self.buffer
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
self.target = None;
}
/// Set new target
pub fn set_target(&mut self, target: usize) {
self.target = Some(target);
}
}
impl Default for ReadBuffer {
fn default() -> Self {
Self::new()
}
}
/// State for buffered writing operations
#[derive(Debug)]
pub struct WriteBuffer {
/// The buffer holding data to write
buffer: BytesMut,
/// Position of next byte to write
position: usize,
/// Maximum buffer size
max_size: usize,
}
impl WriteBuffer {
/// Create new write buffer with default max size (256KB)
pub fn new() -> Self {
Self::with_max_size(256 * 1024)
}
/// Create with specific max size
pub fn with_max_size(max_size: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(8192),
position: 0,
max_size,
}
}
/// Get pending bytes count
pub fn len(&self) -> usize {
self.buffer.len() - self.position
}
/// Check if buffer is empty (all written)
pub fn is_empty(&self) -> bool {
self.position >= self.buffer.len()
}
/// Check if buffer is full
pub fn is_full(&self) -> bool {
self.buffer.len() >= self.max_size
}
/// Get remaining capacity
pub fn remaining_capacity(&self) -> usize {
self.max_size.saturating_sub(self.buffer.len())
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
if self.buffer.len() + data.len() > self.max_size {
return Err(());
}
self.buffer.extend_from_slice(data);
Ok(())
}
/// Get slice of data to write
pub fn pending(&self) -> &[u8] {
&self.buffer[self.position..]
}
/// Advance position by n bytes (after successful write)
pub fn advance(&mut self, n: usize) {
self.position += n;
// If all data written, reset buffer
if self.position >= self.buffer.len() {
self.buffer.clear();
self.position = 0;
}
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
self.position = 0;
}
}
impl Default for WriteBuffer {
fn default() -> Self {
Self::new()
}
}
// ============= Fixed-Size Buffer States =============
/// State for reading a fixed-size header
#[derive(Debug, Clone)]
pub struct HeaderBuffer<const N: usize> {
/// The buffer
data: [u8; N],
/// Bytes filled so far
filled: usize,
}
impl<const N: usize> HeaderBuffer<N> {
/// Create new empty header buffer
pub fn new() -> Self {
Self {
data: [0u8; N],
filled: 0,
}
}
/// Get slice for reading into
pub fn unfilled_mut(&mut self) -> &mut [u8] {
&mut self.data[self.filled..]
}
/// Advance filled count
pub fn advance(&mut self, n: usize) {
self.filled = (self.filled + n).min(N);
}
/// Check if completely filled
pub fn is_complete(&self) -> bool {
self.filled >= N
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
N - self.filled
}
/// Get filled bytes as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[..self.filled]
}
/// Get complete buffer (panics if not complete)
pub fn as_array(&self) -> &[u8; N] {
assert!(self.is_complete());
&self.data
}
/// Take the buffer, resetting state
pub fn take(&mut self) -> [u8; N] {
let data = self.data;
self.data = [0u8; N];
self.filled = 0;
data
}
/// Reset to empty state
pub fn reset(&mut self) {
self.filled = 0;
}
}
impl<const N: usize> Default for HeaderBuffer<N> {
fn default() -> Self {
Self::new()
}
}
// ============= Yield Buffer =============
/// Buffer for yielding data to caller in chunks
#[derive(Debug)]
pub struct YieldBuffer {
data: Bytes,
position: usize,
}
impl YieldBuffer {
/// Create new yield buffer
pub fn new(data: Bytes) -> Self {
Self { data, position: 0 }
}
/// Check if all data has been yielded
pub fn is_empty(&self) -> bool {
self.position >= self.data.len()
}
/// Get remaining bytes
pub fn remaining(&self) -> usize {
self.data.len() - self.position
}
/// Copy data to output slice, return bytes copied
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
let available = &self.data[self.position..];
let to_copy = available.len().min(dst.len());
dst[..to_copy].copy_from_slice(&available[..to_copy]);
self.position += to_copy;
to_copy
}
/// Get remaining data as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[self.position..]
}
}
// ============= Macros =============
/// Macro to simplify state transitions in poll methods
#[macro_export]
macro_rules! transition {
(same) => {
$crate::stream::state::Transition::Same
};
(next $state:expr) => {
$crate::stream::state::Transition::Next($state)
};
(complete $output:expr) => {
$crate::stream::state::Transition::Complete($output)
};
(yield $output:expr, $state:expr) => {
$crate::stream::state::Transition::Yield($output, $state)
};
(error $err:expr) => {
$crate::stream::state::Transition::Error($err)
};
}
/// Macro to match poll ready or return pending
#[macro_export]
macro_rules! ready_or_pending {
($poll:expr) => {
match $poll {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => return std::task::Poll::Pending,
}
};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_read_buffer_basic() {
let mut buf = ReadBuffer::with_target(10);
assert_eq!(buf.remaining(), 10);
assert!(!buf.is_complete());
buf.extend(b"hello");
assert_eq!(buf.len(), 5);
assert_eq!(buf.remaining(), 5);
assert!(!buf.is_complete());
buf.extend(b"world");
assert_eq!(buf.len(), 10);
assert!(buf.is_complete());
}
#[test]
fn test_read_buffer_take() {
let mut buf = ReadBuffer::new();
buf.extend(b"test data");
let data = buf.take();
assert_eq!(&data[..], b"test data");
assert!(buf.is_empty());
}
#[test]
fn test_write_buffer_basic() {
let mut buf = WriteBuffer::with_max_size(100);
assert!(buf.is_empty());
buf.extend(b"hello").unwrap();
assert_eq!(buf.len(), 5);
assert!(!buf.is_empty());
buf.advance(3);
assert_eq!(buf.len(), 2);
assert_eq!(buf.pending(), b"lo");
}
#[test]
fn test_write_buffer_overflow() {
let mut buf = WriteBuffer::with_max_size(10);
assert!(buf.extend(b"short").is_ok());
assert!(buf.extend(b"toolong").is_err());
}
#[test]
fn test_header_buffer() {
let mut buf = HeaderBuffer::<5>::new();
assert!(!buf.is_complete());
assert_eq!(buf.remaining(), 5);
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
buf.advance(3);
assert_eq!(buf.remaining(), 2);
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
buf.advance(2);
assert!(buf.is_complete());
assert_eq!(buf.as_array(), b"hello");
}
#[test]
fn test_yield_buffer() {
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
let mut dst = [0u8; 5];
assert_eq!(buf.copy_to(&mut dst), 5);
assert_eq!(&dst, b"hello");
assert_eq!(buf.remaining(), 6);
let mut dst = [0u8; 10];
assert_eq!(buf.copy_to(&mut dst), 6);
assert_eq!(&dst[..6], b" world");
assert!(buf.is_empty());
}
#[test]
fn test_transition_map() {
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
let t = t.map_output(|s| s.len());
match t {
Transition::Complete(5) => {}
_ => panic!("Expected Complete(5)"),
}
}
#[test]
fn test_poll_result() {
let r: PollResult<i32> = PollResult::Ready(42);
assert!(r.is_ready());
assert_eq!(r.ok(), Some(42));
let r: PollResult<i32> = PollResult::Eof;
assert!(r.is_eof());
assert_eq!(r.ok(), None);
}
}

File diff suppressed because it is too large Load Diff

12
telemt.service Normal file
View File

@@ -0,0 +1,12 @@
[Unit]
Description=Telemt
After=network.target
[Service]
Type=simple
WorkingDirectory=/bin
ExecStart=/bin/telemt /etc/telemt.toml
Restart=on-failure
[Install]
WantedBy=multi-user.target