Tschuss Status Quo - Hallo, Zukunft!
This commit is contained in:
Alexey
2025-12-30 05:08:05 +03:00
parent 44169441b4
commit 3d9150a074
33 changed files with 6079 additions and 0 deletions

60
Cargo.toml Normal file
View File

@@ -0,0 +1,60 @@
[package]
name = "telemt"
version = "1.0.0"
edition = "2021"
rust-version = "1.75"
[dependencies]
# C
libc = "0.2"
# Async runtime
tokio = { version = "1.35", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] }
# Crypto
aes = "0.8"
ctr = "0.9"
cbc = "0.1"
sha2 = "0.10"
sha1 = "0.10"
md-5 = "0.10"
hmac = "0.12"
crc32fast = "1.3"
# Network
socket2 = { version = "0.5", features = ["all"] }
rustls = "0.22"
# Serial
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.8"
# Utils
bytes = "1.5"
thiserror = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
parking_lot = "0.12"
dashmap = "5.5"
lru = "0.12"
rand = "0.8"
chrono = { version = "0.4", features = ["serde"] }
hex = "0.4"
base64 = "0.21"
url = "2.5"
regex = "1.10"
once_cell = "1.19"
# HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
[dev-dependencies]
tokio-test = "0.4"
criterion = "0.5"
proptest = "1.4"
[[bench]]
name = "crypto_bench"
harness = false

12
benches/crypto_bench.rs Normal file
View File

@@ -0,0 +1,12 @@
// Cryptobench
use criterion::{black_box, criterion_group, Criterion};
fn bench_aes_ctr(c: &mut Criterion) {
c.bench_function("aes_ctr_encrypt_64kb", |b| {
let data = vec![0u8; 65536];
b.iter(|| {
let mut enc = AesCtr::new(&[0u8; 32], 0);
black_box(enc.encrypt(&data))
})
});
}

13
config.toml Normal file
View File

@@ -0,0 +1,13 @@
port = 443
[users]
user1 = "00000000000000000000000000000000"
[modes]
classic = true
secure = true
tls = true
tls_domain = "www.github.com"
fast_mode = true
prefer_ipv6 = false

227
src/config/mod.rs Normal file
View File

@@ -0,0 +1,227 @@
//! Configuration
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::Path;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes {
#[serde(default)]
pub classic: bool,
#[serde(default)]
pub secure: bool,
#[serde(default = "default_true")]
pub tls: bool,
}
fn default_true() -> bool { true }
impl Default for ProxyModes {
fn default() -> Self {
Self { classic: true, secure: true, tls: true }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub ad_tag: Option<String>,
#[serde(default)]
pub modes: ProxyModes,
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
#[serde(default = "default_handshake_timeout")]
pub client_handshake_timeout: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect_timeout: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack_timeout: u64,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 10 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 600 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
impl Default for ProxyConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
port: default_port(),
users,
ad_tag: None,
modes: ProxyModes::default(),
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
client_handshake_timeout: default_handshake_timeout(),
tg_connect_timeout: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack_timeout: default_ack_timeout(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
fake_cert_len: default_fake_cert_len(),
}
}
}
impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path)
.map_err(|e| ProxyError::Config(e.to_string()))?;
let mut config: ProxyConfig = toml::from_str(&content)
.map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets
for (user, secret) in &config.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret {
user: user.clone(),
reason: "Must be 32 hex characters".to_string(),
});
}
}
// Default mask_host
if config.mask_host.is_none() {
config.mask_host = Some(config.tls_domain.clone());
}
// Random fake_cert_len
use rand::Rng;
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
Ok(config)
}
pub fn validate(&self) -> Result<()> {
if self.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string()));
}
if !self.modes.classic && !self.modes.secure && !self.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string()));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = ProxyConfig::default();
assert_eq!(config.port, 443);
assert!(config.modes.tls);
assert_eq!(config.client_keepalive, 600);
assert_eq!(config.client_ack_timeout, 300);
}
#[test]
fn test_config_validate() {
let mut config = ProxyConfig::default();
assert!(config.validate().is_ok());
config.users.clear();
assert!(config.validate().is_err());
}
}

351
src/crypto/aes.rs Normal file
View File

@@ -0,0 +1,351 @@
//! AES
use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>;
type Aes256CbcEnc = CbcEncryptor<Aes256>;
type Aes256CbcDec = CbcDecryptor<Aes256>;
/// AES-256-CTR encryptor/decryptor
pub struct AesCtr {
cipher: Aes256Ctr,
}
impl AesCtr {
pub fn new(key: &[u8; 32], iv: u128) -> Self {
let iv_bytes = iv.to_be_bytes();
Self {
cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()),
}
}
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
}
if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
}
let key: [u8; 32] = key.try_into().unwrap();
let iv = u128::from_be_bytes(iv.try_into().unwrap());
Ok(Self::new(&key, iv))
}
/// Encrypt/decrypt data in-place (CTR mode is symmetric)
pub fn apply(&mut self, data: &mut [u8]) {
self.cipher.apply_keystream(data);
}
/// Encrypt data, returning new buffer
pub fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
let mut output = data.to_vec();
self.apply(&mut output);
output
}
/// Decrypt data (for CTR, identical to encrypt)
pub fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
self.encrypt(data)
}
}
/// AES-256-CBC Ciphermagic
pub struct AesCbc {
key: [u8; 32],
iv: [u8; 16],
}
impl AesCbc {
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
Self { key, iv }
}
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
}
if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
}
Ok(Self {
key: key.try_into().unwrap(),
iv: iv.try_into().unwrap(),
})
}
/// Encrypt data using CBC mode
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % 16 != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(Vec::new());
}
let mut buffer = data.to_vec();
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
for chunk in buffer.chunks_mut(16) {
encryptor.encrypt_block_mut(chunk.into());
}
Ok(buffer)
}
/// Decrypt data using CBC mode
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % 16 != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(Vec::new());
}
let mut buffer = data.to_vec();
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
for chunk in buffer.chunks_mut(16) {
decryptor.decrypt_block_mut(chunk.into());
}
Ok(buffer)
}
/// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(());
}
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
for chunk in data.chunks_mut(16) {
encryptor.encrypt_block_mut(chunk.into());
}
Ok(())
}
/// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(());
}
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
for chunk in data.chunks_mut(16) {
decryptor.decrypt_block_mut(chunk.into());
}
Ok(())
}
}
/// Trait for unified encryption interface
pub trait Encryptor: Send + Sync {
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
}
/// Trait for unified decryption interface
pub trait Decryptor: Send + Sync {
fn decrypt(&mut self, data: &[u8]) -> Vec<u8>;
}
impl Encryptor for AesCtr {
fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
AesCtr::encrypt(self, data)
}
}
impl Decryptor for AesCtr {
fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
AesCtr::decrypt(self, data)
}
}
/// No-op encryptor for fast mode
pub struct PassthroughEncryptor;
impl Encryptor for PassthroughEncryptor {
fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
data.to_vec()
}
}
impl Decryptor for PassthroughEncryptor {
fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
data.to_vec()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_aes_ctr_roundtrip() {
let key = [0u8; 32];
let iv = 12345u128;
let original = b"Hello, MTProto!";
let mut enc = AesCtr::new(&key, iv);
let encrypted = enc.encrypt(original);
let mut dec = AesCtr::new(&key, iv);
let decrypted = dec.decrypt(&encrypted);
assert_eq!(original.as_slice(), decrypted.as_slice());
}
#[test]
fn test_aes_cbc_roundtrip() {
let key = [0u8; 32];
let iv = [0u8; 16];
// Must be aligned to 16 bytes
let original = [0u8; 32];
let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap();
let decrypted = cipher.decrypt(&encrypted).unwrap();
assert_eq!(original.as_slice(), decrypted.as_slice());
}
#[test]
fn test_aes_cbc_chaining_works() {
let key = [0x42u8; 32];
let iv = [0x00u8; 16];
let plaintext = [0xAA_u8; 32];
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// CBC Corrections
let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32];
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
}
#[test]
fn test_aes_cbc_known_vector() {
let key = [0u8; 32];
let iv = [0u8; 16];
// 3 Datablocks
let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 2
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 3 - different
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
];
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt + Verify
let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Verify Ciphertexts Block 1 != Block 2
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
}
#[test]
fn test_aes_cbc_in_place() {
let key = [0x12u8; 32];
let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks
let mut buffer = original.clone();
let cipher = AesCbc::new(key, iv);
cipher.encrypt_in_place(&mut buffer).unwrap();
assert_ne!(&buffer[..], &original[..]);
cipher.decrypt_in_place(&mut buffer).unwrap();
assert_eq!(&buffer[..], &original[..]);
}
#[test]
fn test_aes_cbc_empty_data() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
let encrypted = cipher.encrypt(&[]).unwrap();
assert!(encrypted.is_empty());
let decrypted = cipher.decrypt(&[]).unwrap();
assert!(decrypted.is_empty());
}
#[test]
fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes
let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err());
// 17 bytes
let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err());
}
#[test]
fn test_aes_cbc_avalanche_effect() {
// Cipherplane
let key = [0xAB; 32];
let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Один бит отличается
let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First Blocks Diff
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second Blocks Diff
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
}
}

90
src/crypto/hash.rs Normal file
View File

@@ -0,0 +1,90 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
use md5::Md5;
use sha1::Sha1;
use sha2::Digest;
type HmacSha256 = Hmac<Sha256>;
/// SHA-256
pub fn sha256(data: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(data);
hasher.finalize().into()
}
/// SHA-256 HMAC
pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(key)
.expect("HMAC accepts any key length");
mac.update(data);
mac.finalize().into_bytes().into()
}
/// SHA-1
pub fn sha1(data: &[u8]) -> [u8; 20] {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.finalize().into()
}
/// MD5
pub fn md5(data: &[u8]) -> [u8; 16] {
let mut hasher = Md5::new();
hasher.update(data);
hasher.finalize().into()
}
/// CRC32
pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data)
}
/// Middle Proxy Keygen
pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16],
clt_ts: &[u8; 4],
srv_ip: Option<&[u8]>,
clt_port: &[u8; 2],
purpose: &[u8],
clt_ip: Option<&[u8]>,
srv_port: &[u8; 2],
secret: &[u8],
clt_ipv6: Option<&[u8; 16]>,
srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) {
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
let clt_ip = clt_ip.unwrap_or(&EMPTY_IP);
let mut s = Vec::with_capacity(256);
s.extend_from_slice(nonce_srv);
s.extend_from_slice(nonce_clt);
s.extend_from_slice(clt_ts);
s.extend_from_slice(srv_ip);
s.extend_from_slice(clt_port);
s.extend_from_slice(purpose);
s.extend_from_slice(clt_ip);
s.extend_from_slice(srv_port);
s.extend_from_slice(secret);
s.extend_from_slice(nonce_srv);
if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) {
s.extend_from_slice(clt_v6);
s.extend_from_slice(srv_v6);
}
s.extend_from_slice(nonce_clt);
let md5_1 = md5(&s[1..]);
let sha1_sum = sha1(&s);
let md5_2 = md5(&s[2..]);
let mut key = [0u8; 32];
key[..12].copy_from_slice(&md5_1[..12]);
key[12..].copy_from_slice(&sha1_sum);
(key, md5_2)
}

9
src/crypto/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
//! Crypto
pub mod aes;
pub mod hash;
pub mod random;
pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM};

212
src/crypto/random.rs Normal file
View File

@@ -0,0 +1,212 @@
//! Pseudorandom
use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom {
inner: Mutex<SecureRandomInner>,
}
struct SecureRandomInner {
rng: StdRng,
cipher: AesCtr,
buffer: Vec<u8>,
}
impl SecureRandom {
pub fn new() -> Self {
let mut rng = StdRng::from_entropy();
let mut key = [0u8; 32];
rng.fill_bytes(&mut key);
let iv: u128 = rng.gen();
Self {
inner: Mutex::new(SecureRandomInner {
rng,
cipher: AesCtr::new(&key, iv),
buffer: Vec::with_capacity(1024),
}),
}
}
/// Generate random bytes
pub fn bytes(&self, len: usize) -> Vec<u8> {
let mut inner = self.inner.lock();
const CHUNK_SIZE: usize = 512;
while inner.buffer.len() < len {
let mut chunk = vec![0u8; CHUNK_SIZE];
inner.rng.fill_bytes(&mut chunk);
inner.cipher.apply(&mut chunk);
inner.buffer.extend_from_slice(&chunk);
}
inner.buffer.drain(..len).collect()
}
/// Generate random number in range [0, max)
pub fn range(&self, max: usize) -> usize {
if max == 0 {
return 0;
}
let mut inner = self.inner.lock();
inner.rng.gen_range(0..max)
}
/// Generate random bits
pub fn bits(&self, k: usize) -> u64 {
if k == 0 {
return 0;
}
let bytes_needed = (k + 7) / 8;
let bytes = self.bytes(bytes_needed.min(8));
let mut result = 0u64;
for (i, &b) in bytes.iter().enumerate() {
if i >= 8 {
break;
}
result |= (b as u64) << (i * 8);
}
// Mask extra bits
if k < 64 {
result &= (1u64 << k) - 1;
}
result
}
/// Choose random element from slice
pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> {
if slice.is_empty() {
None
} else {
Some(&slice[self.range(slice.len())])
}
}
/// Shuffle slice in place
pub fn shuffle<T>(&self, slice: &mut [T]) {
let mut inner = self.inner.lock();
for i in (1..slice.len()).rev() {
let j = inner.rng.gen_range(0..=i);
slice.swap(i, j);
}
}
/// Generate random u32
pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock();
inner.rng.gen()
}
/// Generate random u64
pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock();
inner.rng.gen()
}
}
impl Default for SecureRandom {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
#[test]
fn test_bytes_uniqueness() {
let rng = SecureRandom::new();
let a = rng.bytes(32);
let b = rng.bytes(32);
assert_ne!(a, b);
}
#[test]
fn test_bytes_length() {
let rng = SecureRandom::new();
assert_eq!(rng.bytes(0).len(), 0);
assert_eq!(rng.bytes(1).len(), 1);
assert_eq!(rng.bytes(100).len(), 100);
assert_eq!(rng.bytes(1000).len(), 1000);
}
#[test]
fn test_range() {
let rng = SecureRandom::new();
for _ in 0..1000 {
let n = rng.range(10);
assert!(n < 10);
}
assert_eq!(rng.range(1), 0);
assert_eq!(rng.range(0), 0);
}
#[test]
fn test_bits() {
let rng = SecureRandom::new();
// Single bit should be 0 or 1
for _ in 0..100 {
assert!(rng.bits(1) <= 1);
}
// 8 bits should be 0-255
for _ in 0..100 {
assert!(rng.bits(8) <= 255);
}
}
#[test]
fn test_choose() {
let rng = SecureRandom::new();
let items = vec![1, 2, 3, 4, 5];
let mut seen = HashSet::new();
for _ in 0..1000 {
if let Some(&item) = rng.choose(&items) {
seen.insert(item);
}
}
// Should have seen all items
assert_eq!(seen.len(), 5);
// Empty slice should return None
let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none());
}
#[test]
fn test_shuffle() {
let rng = SecureRandom::new();
let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut shuffled = original.clone();
rng.shuffle(&mut shuffled);
// Should contain same elements
let mut sorted = shuffled.clone();
sorted.sort();
assert_eq!(sorted, original);
// Should be different order (with very high probability)
assert_ne!(shuffled, original);
}
}

176
src/error.rs Normal file
View File

@@ -0,0 +1,176 @@
//! Error Types
use std::net::SocketAddr;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ProxyError {
// ============= Crypto Errors =============
#[error("Crypto error: {0}")]
Crypto(String),
#[error("Invalid key length: expected {expected}, got {got}")]
InvalidKeyLength { expected: usize, got: usize },
// ============= Protocol Errors =============
#[error("Invalid handshake: {0}")]
InvalidHandshake(String),
#[error("Invalid protocol tag: {0:02x?}")]
InvalidProtoTag([u8; 4]),
#[error("Invalid TLS record: type={record_type}, version={version:02x?}")]
InvalidTlsRecord { record_type: u8, version: [u8; 2] },
#[error("Replay attack detected from {addr}")]
ReplayAttack { addr: SocketAddr },
#[error("Time skew detected: client={client_time}, server={server_time}")]
TimeSkew { client_time: u32, server_time: u32 },
#[error("Invalid message length: {len} (min={min}, max={max})")]
InvalidMessageLength { len: usize, min: usize, max: usize },
#[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")]
ChecksumMismatch { expected: u32, got: u32 },
#[error("Sequence number mismatch: expected={expected}, got={got}")]
SeqNoMismatch { expected: i32, got: i32 },
// ============= Network Errors =============
#[error("Connection timeout to {addr}")]
ConnectionTimeout { addr: String },
#[error("Connection refused by {addr}")]
ConnectionRefused { addr: String },
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
// ============= Proxy Protocol Errors =============
#[error("Invalid proxy protocol header")]
InvalidProxyProtocol,
// ============= Config Errors =============
#[error("Config error: {0}")]
Config(String),
#[error("Invalid secret for user {user}: {reason}")]
InvalidSecret { user: String, reason: String },
// ============= User Errors =============
#[error("User {user} expired")]
UserExpired { user: String },
#[error("User {user} exceeded connection limit")]
ConnectionLimitExceeded { user: String },
#[error("User {user} exceeded data quota")]
DataQuotaExceeded { user: String },
#[error("Unknown user")]
UnknownUser,
// ============= General Errors =============
#[error("Internal error: {0}")]
Internal(String),
}
/// Convenient Result type alias
pub type Result<T> = std::result::Result<T, ProxyError>;
/// Result with optional bad client handling
#[derive(Debug)]
pub enum HandshakeResult<T> {
/// Handshake succeeded
Success(T),
/// Client failed validation, needs masking
BadClient,
/// Error occurred
Error(ProxyError),
}
impl<T> HandshakeResult<T> {
/// Check if successful
pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_))
}
/// Check if bad client
pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient)
}
/// Convert to Result, treating BadClient as error
pub fn into_result(self) -> Result<T> {
match self {
HandshakeResult::Success(v) => Ok(v),
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
HandshakeResult::Error(e) => Err(e),
}
}
/// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> {
match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient => HandshakeResult::BadClient,
HandshakeResult::Error(e) => HandshakeResult::Error(e),
}
}
}
impl<T> From<ProxyError> for HandshakeResult<T> {
fn from(err: ProxyError) -> Self {
HandshakeResult::Error(err)
}
}
impl<T> From<std::io::Error> for HandshakeResult<T> {
fn from(err: std::io::Error) -> Self {
HandshakeResult::Error(ProxyError::Io(err))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
assert!(success.is_success());
assert!(!success.is_bad_client());
let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
assert!(!bad.is_success());
assert!(bad.is_bad_client());
}
#[test]
fn test_handshake_result_map() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2);
match mapped {
HandshakeResult::Success(v) => assert_eq!(v, 84),
_ => panic!("Expected success"),
}
}
#[test]
fn test_error_display() {
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
assert!(err.to_string().contains("1.2.3.4:443"));
let err = ProxyError::InvalidProxyProtocol;
assert!(err.to_string().contains("proxy protocol"));
}
}

158
src/main.rs Normal file
View File

@@ -0,0 +1,158 @@
//! Telemt - MTProxy on Rust
use std::sync::Arc;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::signal;
use tracing::{info, error, Level};
use tracing_subscriber::{FmtSubscriber, EnvFilter};
mod error;
mod crypto;
mod protocol;
mod stream;
mod transport;
mod proxy;
mod config;
mod stats;
mod util;
use config::ProxyConfig;
use stats::{Stats, ReplayChecker};
use transport::ConnectionPool;
use proxy::ClientHandler;
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
// Initialize logging with env filter
// Use RUST_LOG=debug or RUST_LOG=trace for more details
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let subscriber = FmtSubscriber::builder()
.with_env_filter(filter)
.with_target(true)
.with_thread_ids(false)
.with_file(false)
.with_line_number(false)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
// Load configuration
let config_path = std::env::args()
.nth(1)
.unwrap_or_else(|| "config.toml".to_string());
info!("Loading configuration from {}", config_path);
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
error!("Failed to load config: {}", e);
info!("Using default configuration");
ProxyConfig::default()
});
if let Err(e) = config.validate() {
error!("Invalid configuration: {}", e);
std::process::exit(1);
}
let config = Arc::new(config);
info!("Starting MTProto Proxy on port {}", config.port);
info!("Fast mode: {}", config.fast_mode);
info!("Modes: classic={}, secure={}, tls={}",
config.modes.classic, config.modes.secure, config.modes.tls);
// Initialize components
let stats = Arc::new(Stats::new());
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
let pool = Arc::new(ConnectionPool::new());
// Create handler
let handler = Arc::new(ClientHandler::new(
Arc::clone(&config),
Arc::clone(&stats),
Arc::clone(&replay_checker),
Arc::clone(&pool),
));
// Start listener
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
.parse()?;
let listener = TcpListener::bind(addr).await?;
info!("Listening on {}", addr);
// Print proxy links
print_proxy_links(&config);
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
// Main accept loop
let accept_loop = async {
loop {
match listener.accept().await {
Ok((stream, peer)) => {
let handler = Arc::clone(&handler);
tokio::spawn(async move {
handler.handle(stream, peer).await;
});
}
Err(e) => {
error!("Accept error: {}", e);
}
}
}
};
// Graceful shutdown
tokio::select! {
_ = accept_loop => {}
_ = signal::ctrl_c() => {
info!("Shutting down...");
}
}
// Cleanup
pool.close_all().await;
info!("Goodbye!");
Ok(())
}
fn print_proxy_links(config: &ProxyConfig) {
println!("\n=== Proxy Links ===\n");
for (user, secret) in &config.users {
if config.modes.tls {
let tls_secret = format!(
"ee{}{}",
secret,
hex::encode(config.tls_domain.as_bytes())
);
println!(
"{} (TLS): tg://proxy?server=IP&port={}&secret={}",
user, config.port, tls_secret
);
}
if config.modes.secure {
println!(
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}",
user, config.port, secret
);
}
if config.modes.classic {
println!(
"{} (Classic): tg://proxy?server=IP&port={}&secret={}",
user, config.port, secret
);
}
println!();
}
println!("===================\n");
}

261
src/protocol/constants.rs Normal file
View File

@@ -0,0 +1,261 @@
//! Protocol constants and datacenter addresses
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy;
// ============= Telegram Datacenters =============
pub const TG_DATACENTER_PORT: u16 = 443;
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
vec![
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 91)),
IpAddr::V4(Ipv4Addr::new(149, 154, 171, 5)),
]
});
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
vec![
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
IpAddr::V6("2001:b28:f23d:f003::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f004::a".parse().unwrap()),
IpAddr::V6("2001:b28:f23f:f005::a".parse().unwrap()),
]
});
// ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]);
m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]);
m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m
});
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
m
});
// ============= Protocol Tags =============
/// MTProto transport protocol variants
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ProtoTag {
/// Abridged protocol - compact framing
Abridged = 0xefefefef,
/// Intermediate protocol - simple 4-byte length prefix
Intermediate = 0xeeeeeeee,
/// Secure intermediate - with random padding
Secure = 0xdddddddd,
}
impl ProtoTag {
/// Parse protocol tag from 4 bytes
pub fn from_bytes(bytes: [u8; 4]) -> Option<Self> {
match u32::from_le_bytes(bytes) {
0xefefefef => Some(ProtoTag::Abridged),
0xeeeeeeee => Some(ProtoTag::Intermediate),
0xdddddddd => Some(ProtoTag::Secure),
_ => None,
}
}
/// Convert to 4 bytes (little-endian)
pub fn to_bytes(self) -> [u8; 4] {
(self as u32).to_le_bytes()
}
/// Get protocol tag as bytes slice
pub fn as_bytes(&self) -> &'static [u8; 4] {
match self {
ProtoTag::Abridged => &PROTO_TAG_ABRIDGED,
ProtoTag::Intermediate => &PROTO_TAG_INTERMEDIATE,
ProtoTag::Secure => &PROTO_TAG_SECURE,
}
}
}
/// Protocol tag bytes
pub const PROTO_TAG_ABRIDGED: [u8; 4] = [0xef, 0xef, 0xef, 0xef];
pub const PROTO_TAG_INTERMEDIATE: [u8; 4] = [0xee, 0xee, 0xee, 0xee];
pub const PROTO_TAG_SECURE: [u8; 4] = [0xdd, 0xdd, 0xdd, 0xdd];
// ============= Handshake Layout =============
/// Bytes to skip at the start of handshake
pub const SKIP_LEN: usize = 8;
/// Pre-key length (before hashing with secret)
pub const PREKEY_LEN: usize = 32;
/// AES key length
pub const KEY_LEN: usize = 32;
/// AES IV length
pub const IV_LEN: usize = 16;
/// Total handshake length
pub const HANDSHAKE_LEN: usize = 64;
/// Position of protocol tag in decrypted handshake
pub const PROTO_TAG_POS: usize = 56;
/// Position of datacenter index
pub const DC_IDX_POS: usize = 60;
// ============= Message Limits =============
/// Minimum message length
pub const MIN_MSG_LEN: usize = 12;
/// Maximum message length (16 MB)
pub const MAX_MSG_LEN: usize = 1 << 24;
/// CBC block padding size
pub const CBC_PADDING: usize = 16;
/// Padding filler bytes
pub const PADDING_FILLER: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
// ============= TLS Constants =============
/// Minimum certificate length for detection
pub const MIN_CERT_LEN: usize = 1024;
/// TLS 1.3 version bytes
pub const TLS_VERSION: [u8; 2] = [0x03, 0x03];
/// TLS record type: Handshake
pub const TLS_RECORD_HANDSHAKE: u8 = 0x16;
/// TLS record type: Change Cipher Spec
pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14;
/// TLS record type: Application Data
pub const TLS_RECORD_APPLICATION: u8 = 0x17;
/// TLS record type: Alert
pub const TLS_RECORD_ALERT: u8 = 0x15;
/// Maximum TLS record size
pub const MAX_TLS_RECORD_SIZE: usize = 16384;
/// Maximum TLS chunk size (with overhead)
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 24;
// ============= Timeouts =============
/// Default handshake timeout in seconds
pub const DEFAULT_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
/// Default connect timeout in seconds
pub const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Default keepalive interval in seconds
pub const DEFAULT_KEEPALIVE_SECS: u64 = 600;
/// Default ACK timeout in seconds
pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes =============
/// Default buffer size
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
/// Small buffer size for bad client handling
pub const SMALL_BUFFER_SIZE: usize = 8192;
// ============= Statistics =============
/// Duration buckets for histogram metrics
pub static DURATION_BUCKETS: &[f64] = &[
0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0,
];
// ============= Reserved Nonce Patterns =============
/// Reserved first bytes of nonce (must avoid)
pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef];
/// Reserved 4-byte beginnings of nonce
pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[
[0x48, 0x45, 0x41, 0x44], // HEAD
[0x50, 0x4F, 0x53, 0x54], // POST
[0x47, 0x45, 0x54, 0x20], // GET
[0xee, 0xee, 0xee, 0xee], // Intermediate
[0xdd, 0xdd, 0xdd, 0xdd], // Secure
[0x16, 0x03, 0x01, 0x02], // TLS
];
/// Reserved continuation bytes (bytes 4-7)
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
[0x00, 0x00, 0x00, 0x00],
];
// ============= RPC Constants (for Middle Proxy) =============
/// RPC Proxy Request
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
/// RPC Proxy Answer
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
/// RPC Close Extended
pub const RPC_CLOSE_EXT: [u8; 4] = [0xa2, 0x34, 0xb6, 0x5e];
/// RPC Simple ACK
pub const RPC_SIMPLE_ACK: [u8; 4] = [0x9b, 0x40, 0xac, 0x3b];
/// RPC Unknown
pub const RPC_UNKNOWN: [u8; 4] = [0xdf, 0xa2, 0x30, 0x57];
/// RPC Handshake
pub const RPC_HANDSHAKE: [u8; 4] = [0xf5, 0xee, 0x82, 0x76];
/// RPC Nonce
pub const RPC_NONCE: [u8; 4] = [0xaa, 0x87, 0xcb, 0x7a];
/// RPC Flags
pub mod rpc_flags {
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG: u32 = 0x8;
pub const FLAG_MAGIC: u32 = 0x1000;
pub const FLAG_EXTMODE2: u32 = 0x20000;
pub const FLAG_PAD: u32 = 0x8000000;
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_proto_tag_roundtrip() {
for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let bytes = tag.to_bytes();
let parsed = ProtoTag::from_bytes(bytes).unwrap();
assert_eq!(tag, parsed);
}
}
#[test]
fn test_proto_tag_values() {
assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED);
assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE);
assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE);
}
#[test]
fn test_invalid_proto_tag() {
assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none());
assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none());
}
#[test]
fn test_datacenters_count() {
assert_eq!(TG_DATACENTERS_V4.len(), 5);
assert_eq!(TG_DATACENTERS_V6.len(), 5);
}
}

120
src/protocol/frame.rs Normal file
View File

@@ -0,0 +1,120 @@
//! MTProto frame types and metadata
use std::collections::HashMap;
/// Extra metadata associated with a frame
#[derive(Debug, Clone, Default)]
pub struct FrameExtra {
/// Quick ACK flag - request immediate acknowledgment
pub quickack: bool,
/// Simple ACK - this is an acknowledgment message
pub simple_ack: bool,
/// Skip sending - internal flag to skip forwarding
pub skip_send: bool,
/// Custom key-value metadata
pub custom: HashMap<String, String>,
}
impl FrameExtra {
/// Create new empty frame extra
pub fn new() -> Self {
Self::default()
}
/// Create with quickack flag set
pub fn with_quickack() -> Self {
Self {
quickack: true,
..Default::default()
}
}
/// Create with simple_ack flag set
pub fn with_simple_ack() -> Self {
Self {
simple_ack: true,
..Default::default()
}
}
/// Check if any flags are set
pub fn has_flags(&self) -> bool {
self.quickack || self.simple_ack || self.skip_send
}
}
/// Result of reading a frame
#[derive(Debug)]
pub enum FrameReadResult {
/// Successfully read a frame with data and metadata
Data(Vec<u8>, FrameExtra),
/// Connection closed normally
Closed,
/// Need more data (for non-blocking reads)
WouldBlock,
}
/// Frame encoding/decoding mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameMode {
/// Abridged - 1 or 4 byte length prefix
Abridged,
/// Intermediate - 4 byte length prefix
Intermediate,
/// Secure Intermediate - 4 byte length with padding
SecureIntermediate,
/// Full MTProto - with seq_no and CRC32
Full,
}
impl FrameMode {
/// Get maximum overhead for this frame mode
pub fn max_overhead(&self) -> usize {
match self {
FrameMode::Abridged => 4,
FrameMode::Intermediate => 4,
FrameMode::SecureIntermediate => 4 + 3, // length + padding
FrameMode::Full => 12 + 16, // header + max CBC padding
}
}
}
/// Validate message length for MTProto
pub fn validate_message_length(len: usize) -> bool {
use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER};
len >= MIN_MSG_LEN && len <= MAX_MSG_LEN && len % PADDING_FILLER.len() == 0
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_extra_default() {
let extra = FrameExtra::default();
assert!(!extra.quickack);
assert!(!extra.simple_ack);
assert!(!extra.skip_send);
assert!(!extra.has_flags());
}
#[test]
fn test_frame_extra_flags() {
let extra = FrameExtra::with_quickack();
assert!(extra.quickack);
assert!(extra.has_flags());
let extra = FrameExtra::with_simple_ack();
assert!(extra.simple_ack);
assert!(extra.has_flags());
}
#[test]
fn test_validate_message_length() {
assert!(validate_message_length(12)); // MIN_MSG_LEN
assert!(validate_message_length(16));
assert!(!validate_message_length(8)); // Too small
assert!(!validate_message_length(13)); // Not aligned to 4
}
}

11
src/protocol/mod.rs Normal file
View File

@@ -0,0 +1,11 @@
//! MTProto Defs + Cons
pub mod constants;
pub mod frame;
pub mod obfuscation;
pub mod tls;
pub use constants::*;
pub use frame::*;
pub use obfuscation::*;
pub use tls::*;

217
src/protocol/obfuscation.rs Normal file
View File

@@ -0,0 +1,217 @@
//! MTProto Obfuscation
use crate::crypto::{sha256, AesCtr};
use crate::error::Result;
use super::constants::*;
/// Obfuscation parameters from handshake
#[derive(Debug, Clone)]
pub struct ObfuscationParams {
/// Key for decrypting client -> proxy traffic
pub decrypt_key: [u8; 32],
/// IV for decrypting client -> proxy traffic
pub decrypt_iv: u128,
/// Key for encrypting proxy -> client traffic
pub encrypt_key: [u8; 32],
/// IV for encrypting proxy -> client traffic
pub encrypt_iv: u128,
/// Protocol tag (abridged/intermediate/secure)
pub proto_tag: ProtoTag,
/// Datacenter index
pub dc_idx: i16,
}
impl ObfuscationParams {
/// Parse obfuscation parameters from handshake bytes
/// Returns None if handshake doesn't match any user secret
pub fn from_handshake(
handshake: &[u8; HANDSHAKE_LEN],
secrets: &[(String, Vec<u8>)], // (username, secret_bytes)
) -> Option<(Self, String)> {
// Extract prekey and IV for decryption
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets {
// Derive decryption key
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret);
let decrypt_key = sha256(&dec_key_input);
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Create decryptor and decrypt handshake
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => continue, // Try next secret
};
// Extract DC index
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
// Derive encryption key
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret);
let encrypt_key = sha256(&enc_key_input);
let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
return Some((
ObfuscationParams {
decrypt_key,
decrypt_iv,
encrypt_key,
encrypt_iv,
proto_tag,
dc_idx,
},
username.clone(),
));
}
None
}
/// Create AES-CTR decryptor for client -> proxy direction
pub fn create_decryptor(&self) -> AesCtr {
AesCtr::new(&self.decrypt_key, self.decrypt_iv)
}
/// Create AES-CTR encryptor for proxy -> client direction
pub fn create_encryptor(&self) -> AesCtr {
AesCtr::new(&self.encrypt_key, self.encrypt_iv)
}
/// Get the combined encrypt key and IV for fast mode
pub fn enc_key_iv(&self) -> Vec<u8> {
let mut result = Vec::with_capacity(KEY_LEN + IV_LEN);
result.extend_from_slice(&self.encrypt_key);
result.extend_from_slice(&self.encrypt_iv.to_be_bytes());
result
}
}
/// Generate a valid random nonce for Telegram handshake
pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; HANDSHAKE_LEN] {
loop {
let nonce_vec = random_bytes(HANDSHAKE_LEN);
let mut nonce = [0u8; HANDSHAKE_LEN];
nonce.copy_from_slice(&nonce_vec);
if is_valid_nonce(&nonce) {
return nonce;
}
}
}
/// Check if nonce is valid (not matching reserved patterns)
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
// Check first byte
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false;
}
// Check first 4 bytes
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false;
}
// Check bytes 4-7
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false;
}
true
}
/// Prepare nonce for sending to Telegram
pub fn prepare_tg_nonce(
nonce: &mut [u8; HANDSHAKE_LEN],
proto_tag: ProtoTag,
enc_key_iv: Option<&[u8]>, // For fast mode
) {
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// For fast mode, copy the reversed enc_key_iv
if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
}
}
/// Encrypt the outgoing nonce for Telegram
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// Derive encryption key from the nonce itself
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
// Only encrypt from PROTO_TAG_POS onwards
let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_valid_nonce() {
// Valid nonce
let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid));
// Invalid: starts with 0xef
let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid));
// Invalid: starts with HEAD
let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid));
// Invalid: bytes 4-7 are zeros
let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid));
}
#[test]
fn test_generate_nonce() {
let mut counter = 0u8;
let nonce = generate_nonce(|n| {
counter = counter.wrapping_add(1);
vec![counter; n]
});
assert!(is_valid_nonce(&nonce));
assert_eq!(nonce.len(), HANDSHAKE_LEN);
}
}

244
src/protocol/tls.rs Normal file
View File

@@ -0,0 +1,244 @@
//! Fake TLS 1.3 Handshake
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
use crate::error::{ProxyError, Result};
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
/// TLS handshake digest length
pub const TLS_DIGEST_LEN: usize = 32;
/// Position of digest in TLS ClientHello
pub const TLS_DIGEST_POS: usize = 11;
/// Length to store for replay protection (first 16 bytes of digest)
pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Time skew limits for anti-replay (in seconds)
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
/// Result of validating TLS handshake
#[derive(Debug)]
pub struct TlsValidation {
/// Username that validated
pub user: String,
/// Session ID from ClientHello
pub session_id: Vec<u8>,
/// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest
pub timestamp: u32,
}
/// Validate TLS ClientHello against user secrets
pub fn validate_tls_handshake(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool,
) -> Option<TlsValidation> {
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
return None;
}
// Extract digest
let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.try_into()
.ok()?;
// Extract session ID
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec();
// Build message for HMAC (with zeroed digest)
let mut msg = handshake.to_vec();
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
// Get current time
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
for (user, secret) in secrets {
let computed = sha256_hmac(secret, &msg);
// XOR digests
let xored: Vec<u8> = digest.iter()
.zip(computed.iter())
.map(|(a, b)| a ^ b)
.collect();
// Check that first 28 bytes are zeros (timestamp in last 4)
if !xored[..28].iter().all(|&b| b == 0) {
continue;
}
// Extract timestamp
let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap());
let time_diff = now - timestamp as i64;
// Check time skew
if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time)
let is_boot_time = timestamp < 60 * 60 * 24 * 1000;
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
continue;
}
}
return Some(TlsValidation {
user: user.clone(),
session_id,
digest,
timestamp,
});
}
None
}
/// Generate a fake X25519 public key for TLS
/// This generates a value that looks like a valid X25519 key
pub fn gen_fake_x25519_key() -> [u8; 32] {
// For simplicity, just generate random 32 bytes
// In real X25519, this would be a point on the curve
let bytes = SECURE_RANDOM.bytes(32);
bytes.try_into().unwrap()
}
/// Build TLS ServerHello response
pub fn build_server_hello(
secret: &[u8],
client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8],
fake_cert_len: usize,
) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key();
// TLS extensions
let mut extensions = Vec::new();
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve
extensions.extend_from_slice(&x25519_key);
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
// ServerHello body
let mut srv_hello = Vec::new();
srv_hello.extend_from_slice(&TLS_VERSION);
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest
srv_hello.push(session_id.len() as u8);
srv_hello.extend_from_slice(session_id);
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
srv_hello.push(0x00); // No compression
srv_hello.extend_from_slice(&extensions);
// Build complete packet
let mut hello_pkt = Vec::new();
// ServerHello record
hello_pkt.push(TLS_RECORD_HANDSHAKE);
hello_pkt.extend_from_slice(&TLS_VERSION);
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes());
hello_pkt.push(0x02); // ServerHello message type
let len_bytes = (srv_hello.len() as u32).to_be_bytes();
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length
hello_pkt.extend_from_slice(&srv_hello);
// Change Cipher Spec record
hello_pkt.extend_from_slice(&[
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x01, 0x01
]);
// Application Data record (fake certificate)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
hello_pkt.push(TLS_RECORD_APPLICATION);
hello_pkt.extend_from_slice(&TLS_VERSION);
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
hello_pkt.extend_from_slice(&fake_cert);
// Compute HMAC for the response
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len());
hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&hello_pkt);
let response_digest = sha256_hmac(secret, &hmac_input);
// Insert computed digest
// Position: after record header (5) + message type/length (4) + version (2) = 11
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&response_digest);
hello_pkt
}
/// Check if bytes look like a TLS ClientHello
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
if first_bytes.len() < 3 {
return false;
}
// TLS record header: 0x16 0x03 0x01
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
&& first_bytes[2] == 0x01
}
/// Parse TLS record header, returns (record_type, length)
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0];
let version = [header[1], header[2]];
// We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3
if version != [0x03, 0x01] && version != TLS_VERSION {
return None;
}
let length = u16::from_be_bytes([header[3], header[4]]);
Some((record_type, length))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_tls_handshake() {
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version
assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short
}
#[test]
fn test_parse_tls_record_header() {
let header = [0x16, 0x03, 0x01, 0x02, 0x00];
let result = parse_tls_record_header(&header).unwrap();
assert_eq!(result.0, TLS_RECORD_HANDSHAKE);
assert_eq!(result.1, 512);
let header = [0x17, 0x03, 0x03, 0x40, 0x00];
let result = parse_tls_record_header(&header).unwrap();
assert_eq!(result.0, TLS_RECORD_APPLICATION);
assert_eq!(result.1, 16384);
}
#[test]
fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key();
let key2 = gen_fake_x25519_key();
assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random
}
}

378
src/proxy/client.rs Normal file
View File

@@ -0,0 +1,378 @@
//! Client Handler
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::{debug, info, warn, error, trace};
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result, HandshakeResult};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{ConnectionPool, configure_client_socket};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::crypto::AesCtr;
use super::handshake::{
handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
};
use super::relay::relay_bidirectional;
use super::masking::handle_bad_client;
/// Client connection handler
pub struct ClientHandler {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>,
}
impl ClientHandler {
/// Create new client handler
pub fn new(
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>,
) -> Self {
Self {
config,
stats,
replay_checker,
pool,
}
}
/// Handle a client connection
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
self.stats.increment_connects_all();
debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket(
&stream,
self.config.client_keepalive,
self.config.client_ack_timeout,
) {
debug!(peer = %peer, error = %e, "Failed to configure client socket");
}
// Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
let result = timeout(
handshake_timeout,
self.do_handshake(stream, peer)
).await;
match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
}
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
}
Err(_) => {
self.stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
}
}
}
/// Perform handshake and relay
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
// Read first bytes to determine handshake type
let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
if is_tls {
self.handle_tls_client(stream, peer, first_bytes).await
} else {
self.handle_direct_client(stream, peer, first_bytes).await
}
}
/// Handle TLS-wrapped client
async fn handle_tls_client(
&self,
mut stream: TcpStream,
peer: SocketAddr,
first_bytes: [u8; 5],
) -> Result<()> {
// Read TLS handshake length
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await;
return Ok(());
}
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
// Split stream for reading/writing
let (read_half, write_half) = stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake,
read_half,
write_half,
peer,
&self.config,
&self.replay_checker,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
// Read MTProto handshake through TLS
debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake,
tls_reader,
tls_writer,
peer,
&self.config,
&self.replay_checker,
true,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
// Handle authenticated client
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
}
/// Handle direct (non-TLS) client
async fn handle_direct_client(
&self,
mut stream: TcpStream,
peer: SocketAddr,
first_bytes: [u8; 5],
) -> Result<()> {
// Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await;
return Ok(());
}
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
// Split stream
let (read_half, write_half) = stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake,
read_half,
write_half,
peer,
&self.config,
&self.replay_checker,
false,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
}
/// Handle authenticated client - connect to Telegram and relay
async fn handle_authenticated_inner<R, W>(
&self,
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = &success.user;
// Check user limits
if let Err(e) = self.check_user_limits(user) {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
// Get datacenter address
let dc_addr = self.get_dc_addr(success.dc_idx)?;
info!(
user = %user,
peer = %success.peer,
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
fast_mode = self.config.fast_mode,
"Connecting to Telegram"
);
// Connect to Telegram
let tg_stream = self.pool.get(dc_addr).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
// Perform Telegram handshake and get crypto streams
let (tg_reader, tg_writer) = self.do_tg_handshake(
tg_stream,
&success,
).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
// Update stats
self.stats.increment_user_connects(user);
self.stats.increment_user_curr_connects(user);
// Relay traffic - передаём Arc::clone(&self.stats)
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&self.stats),
).await;
// Update stats
self.stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"),
}
relay_result
}
/// Check user limits (expiration, connection count, data quota)
fn check_user_limits(&self, user: &str) -> Result<()> {
// Check expiration
if let Some(expiration) = self.config.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() });
}
}
// Check connection limit
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
let current = self.stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
}
}
// Check data quota
if let Some(quota) = self.config.user_data_quota.get(user) {
let used = self.stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
}
}
Ok(())
}
/// Get datacenter address by index
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if self.config.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
datacenters.get(idx)
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
.ok_or_else(|| ProxyError::InvalidHandshake(
format!("Invalid DC index: {}", dc_idx)
))
}
/// Perform handshake with Telegram server
/// Returns crypto reader and writer for TG connection
async fn do_tg_handshake(
&self,
mut stream: TcpStream,
success: &HandshakeSuccess,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
&success.dec_key, // Client's dec key
success.dec_iv,
self.config.fast_mode,
);
// Encrypt nonce
let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
"Sending nonce to Telegram"
);
// Send to Telegram
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
debug!(peer = %success.peer, "Nonce sent to Telegram");
// Split stream and wrap with crypto
let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
let tg_reader = CryptoReader::new(read_half, decryptor);
let tg_writer = CryptoWriter::new(write_half, encryptor);
Ok((tg_reader, tg_writer))
}
}

411
src/proxy/handshake.rs Normal file
View File

@@ -0,0 +1,411 @@
//! MTProto Handshake Magics
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info};
use crate::crypto::{sha256, AesCtr};
use crate::crypto::random::SECURE_RANDOM;
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
use crate::error::{ProxyError, HandshakeResult};
use crate::stats::ReplayChecker;
use crate::config::ProxyConfig;
/// Result of successful handshake
#[derive(Debug, Clone)]
pub struct HandshakeSuccess {
/// Authenticated user name
pub user: String,
/// Target datacenter index
pub dc_idx: i16,
/// Protocol variant (abridged/intermediate/secure)
pub proto_tag: ProtoTag,
/// Decryption key and IV (for reading from client)
pub dec_key: [u8; 32],
pub dec_iv: u128,
/// Encryption key and IV (for writing to client)
pub enc_key: [u8; 32],
pub enc_iv: u128,
/// Client address
pub peer: SocketAddr,
/// Whether TLS was used
pub is_tls: bool,
}
/// Handle fake TLS handshake
pub async fn handle_tls_handshake<R, W>(
handshake: &[u8],
reader: R,
mut writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient;
}
// Extract digest for replay check
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
// Check for replay
if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected");
return HandshakeResult::BadClient;
}
// Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
.filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
})
.collect();
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
// Validate handshake
let validation = match tls::validate_tls_handshake(
handshake,
&secrets,
config.ignore_time_skew,
) {
Some(v) => v,
None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user");
return HandshakeResult::BadClient;
}
};
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient,
};
// Build and send response
let response = tls::build_server_hello(
secret,
&validation.digest,
&validation.session_id,
config.fake_cert_len,
);
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await {
return HandshakeResult::Error(ProxyError::Io(e));
}
if let Err(e) = writer.flush().await {
return HandshakeResult::Error(ProxyError::Io(e));
}
// Record for replay protection
replay_checker.add_tls_digest(digest_half);
info!(
peer = %peer,
user = %validation.user,
"TLS handshake successful"
);
HandshakeResult::Success((
FakeTlsReader::new(reader),
FakeTlsWriter::new(writer),
validation.user,
))
}
/// Handle MTProto obfuscation handshake
pub async fn handle_mtproto_handshake<R, W>(
handshake: &[u8; HANDSHAKE_LEN],
reader: R,
writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
// Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
debug!(
peer = %peer,
dec_prekey_iv = %hex::encode(dec_prekey_iv),
"Extracted prekey+IV from handshake"
);
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient;
}
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret
for (user, secret_hex) in &config.users {
let secret = match hex::decode(secret_hex) {
Ok(s) => s,
Err(_) => continue,
};
// Derive decryption key
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Decrypt handshake to check protocol tag
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake);
trace!(
peer = %peer,
user = %user,
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
"Decrypted handshake tail"
);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => {
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
continue;
}
};
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
// Check if mode is enabled
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure }
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic,
};
if !mode_ok {
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
continue;
}
// Extract DC index
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
// Derive encryption key
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
// Record for replay protection
replay_checker.add_handshake(dec_prekey_iv);
// Create new cipher instances
let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv);
let success = HandshakeSuccess {
user: user.clone(),
dc_idx,
proto_tag,
dec_key,
dec_iv,
enc_key,
enc_iv,
peer,
is_tls,
};
info!(
peer = %peer,
user = %user,
dc = dc_idx,
proto = ?proto_tag,
tls = is_tls,
"MTProto handshake successful"
);
return HandshakeResult::Success((
CryptoReader::new(reader, decryptor),
CryptoWriter::new(writer, encryptor),
success,
));
}
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient
}
/// Generate nonce for Telegram connection
///
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
/// This means: client's enc_key becomes TG's dec_key and vice versa.
pub fn generate_tg_nonce(
proto_tag: ProtoTag,
client_dec_key: &[u8; 32],
client_dec_iv: u128,
fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
// Check reserved patterns
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
continue;
}
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
continue;
}
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
// In fast mode, we make TG use the same keys as client but swapped:
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
// - What we encrypt TO TG = what we decrypt FROM client
if fast_mode {
// Put client's dec_key + dec_iv into nonce[8:56]
// This will be used by TG for encryption TO us
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes());
}
// Now compute what keys WE will use for TG connection
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
debug!(
fast_mode = fast_mode,
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
"Generated TG nonce"
);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
}
}
/// Encrypt nonce for sending to Telegram
///
/// Only the part from PROTO_TAG_POS onwards is encrypted.
/// The encryption key is derived from enc_key_iv in the nonce itself.
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// enc_key_iv is at nonce[8:56]
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
// Key for encrypting is just the first 32 bytes of enc_key_iv
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv);
// Encrypt the entire nonce first, then take only the encrypted tail
let encrypted_full = encryptor.encrypt(nonce);
// Result: unencrypted head + encrypted tail
let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
trace!(
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
"Encrypted nonce tail"
);
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
// Check length
assert_eq!(nonce.len(), HANDSHAKE_LEN);
// Check proto tag is set
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
// First PROTO_TAG_POS bytes should be unchanged
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
// Rest should be different (encrypted)
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
}

115
src/proxy/masking.rs Normal file
View File

@@ -0,0 +1,115 @@
//! Masking - forward unrecognized traffic to mask host
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::debug;
use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
const MASK_BUFFER_SIZE: usize = 8192;
/// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client(
mut client: TcpStream,
initial_data: &[u8],
config: &ProxyConfig,
) {
if !config.mask {
// Masking disabled, just consume data
consume_client_data(client).await;
return;
}
let mask_host = config.mask_host.as_deref()
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port;
debug!(
host = %mask_host,
port = mask_port,
"Forwarding bad client to mask host"
);
// Connect to mask host
let mask_addr = format!("{}:{}", mask_host, mask_port);
let connect_result = timeout(
MASK_TIMEOUT,
TcpStream::connect(&mask_addr)
).await;
let mut mask_stream = match connect_result {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await;
return;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data(client).await;
return;
}
};
// Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() {
return;
}
// Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match client_read.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await;
break;
}
Ok(n) => {
if mask_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
}
});
let m2c = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = client_write.shutdown().await;
break;
}
Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
}
});
// Wait for either to complete
tokio::select! {
_ = c2m => {}
_ = m2c => {}
}
}
/// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await {
if n == 0 {
break;
}
}
}

11
src/proxy/mod.rs Normal file
View File

@@ -0,0 +1,11 @@
//! Proxy Defs
pub mod handshake;
pub mod client;
pub mod relay;
pub mod masking;
pub use handshake::*;
pub use client::ClientHandler;
pub use relay::*;
pub use masking::*;

162
src/proxy/relay.rs Normal file
View File

@@ -0,0 +1,162 @@
//! Bidirectional Relay
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::stats::Stats;
use std::sync::atomic::{AtomicU64, Ordering};
const BUFFER_SIZE: usize = 65536;
/// Relay data bidirectionally between client and server
pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut client_reader: CR,
mut client_writer: CW,
mut server_reader: SR,
mut server_writer: SW,
user: &str,
stats: Arc<Stats>,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static,
{
let user_c2s = user.to_string();
let user_s2c = user.to_string();
// Используем Arc::clone вместо stats.clone()
let stats_c2s = Arc::clone(&stats);
let stats_s2c = Arc::clone(&stats);
let c2s_bytes = Arc::new(AtomicU64::new(0));
let s2c_bytes = Arc::new(AtomicU64::new(0));
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
// Client -> Server task
let c2s = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE];
let mut total_bytes = 0u64;
let mut msg_count = 0u64;
loop {
match client_reader.read(&mut buf).await {
Ok(0) => {
debug!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
"Client closed connection (C->S)"
);
let _ = server_writer.shutdown().await;
break;
}
Ok(n) => {
total_bytes += n as u64;
msg_count += 1;
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
stats_c2s.increment_user_msgs_from(&user_c2s);
trace!(
user = %user_c2s,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data"
);
if let Err(e) = server_writer.write_all(&buf[..n]).await {
debug!(user = %user_c2s, error = %e, "Failed to write to server");
break;
}
if let Err(e) = server_writer.flush().await {
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
break;
}
}
Err(e) => {
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
break;
}
}
}
});
// Server -> Client task
let s2c = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE];
let mut total_bytes = 0u64;
let mut msg_count = 0u64;
loop {
match server_reader.read(&mut buf).await {
Ok(0) => {
debug!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
"Server closed connection (S->C)"
);
let _ = client_writer.shutdown().await;
break;
}
Ok(n) => {
total_bytes += n as u64;
msg_count += 1;
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
stats_s2c.increment_user_msgs_to(&user_s2c);
trace!(
user = %user_s2c,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"S->C data"
);
if let Err(e) = client_writer.write_all(&buf[..n]).await {
debug!(user = %user_s2c, error = %e, "Failed to write to client");
break;
}
if let Err(e) = client_writer.flush().await {
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
break;
}
}
Err(e) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break;
}
}
}
});
// Wait for either direction to complete
tokio::select! {
result = c2s => {
if let Err(e) = result {
warn!(error = %e, "C->S task panicked");
}
}
result = s2c => {
if let Err(e) = result {
warn!(error = %e, "S->C task panicked");
}
}
}
debug!(
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
"Relay finished"
);
Ok(())
}

223
src/stats/mod.rs Normal file
View File

@@ -0,0 +1,223 @@
//! Statistics
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use parking_lot::RwLock;
use lru::LruCache;
use std::num::NonZeroUsize;
/// Thread-safe statistics
#[derive(Default)]
pub struct Stats {
// Global counters
connects_all: AtomicU64,
connects_bad: AtomicU64,
handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>,
// Start time
start_time: RwLock<Option<Instant>>,
}
/// Per-user statistics
#[derive(Default)]
pub struct UserStats {
pub connects: AtomicU64,
pub curr_connects: AtomicU64,
pub octets_from_client: AtomicU64,
pub octets_to_client: AtomicU64,
pub msgs_from_client: AtomicU64,
pub msgs_to_client: AtomicU64,
}
impl Stats {
pub fn new() -> Self {
let stats = Self::default();
*stats.start_time.write() = Some(Instant::now());
stats
}
// Global stats
pub fn increment_connects_all(&self) {
self.connects_all.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
}
pub fn decrement_user_curr_connects(&self, user: &str) {
if let Some(stats) = self.user_stats.get(user) {
stats.curr_connects.fetch_sub(1, Ordering::Relaxed);
}
}
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats
.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats
.get(user)
.map(|s| {
s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed)
})
.unwrap_or(0)
}
pub fn uptime_secs(&self) -> f64 {
self.start_time.read()
.map(|t| t.elapsed().as_secs_f64())
.unwrap_or(0.0)
}
}
// Arc<Stats> Hightech Stats :D
/// Replay attack checker using LRU cache
pub struct ReplayChecker {
handshakes: RwLock<LruCache<Vec<u8>, ()>>,
tls_digests: RwLock<LruCache<Vec<u8>, ()>>,
}
impl ReplayChecker {
pub fn new(capacity: usize) -> Self {
let cap = NonZeroUsize::new(capacity.max(1)).unwrap();
Self {
handshakes: RwLock::new(LruCache::new(cap)),
tls_digests: RwLock::new(LruCache::new(cap)),
}
}
pub fn check_handshake(&self, data: &[u8]) -> bool {
self.handshakes.read().contains(&data.to_vec())
}
pub fn add_handshake(&self, data: &[u8]) {
self.handshakes.write().put(data.to_vec(), ());
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.tls_digests.read().contains(&data.to_vec())
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.tls_digests.write().put(data.to_vec(), ());
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new());
// Симулируем использование из разных "задач"
let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats);
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
// Все инкременты должны быть видны
assert_eq!(stats.get_connects_all(), 3);
}
#[test]
fn test_user_stats_shared() {
let stats = Arc::new(Stats::new());
let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats);
stats1.add_user_octets_from("user1", 100);
stats2.add_user_octets_from("user1", 200);
stats1.add_user_octets_to("user1", 50);
assert_eq!(stats.get_user_total_octets("user1"), 350);
}
#[test]
fn test_concurrent_user_connects() {
let stats = Arc::new(Stats::new());
stats.increment_user_curr_connects("user1");
stats.increment_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 2);
stats.decrement_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 1);
}
}

474
src/stream/crypto_stream.rs Normal file
View File

@@ -0,0 +1,474 @@
//! Encrypted stream wrappers using AES-CTR
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::crypto::AesCtr;
use parking_lot::Mutex;
/// Reader that decrypts data using AES-CTR
pub struct CryptoReader<R> {
upstream: R,
decryptor: AesCtr,
buffer: BytesMut,
}
impl<R> CryptoReader<R> {
/// Create new crypto reader
pub fn new(upstream: R, decryptor: AesCtr) -> Self {
Self {
upstream,
decryptor,
buffer: BytesMut::with_capacity(8192),
}
}
/// Get reference to upstream
pub fn get_ref(&self) -> &R {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
}
impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
let this = self.get_mut();
if !this.buffer.is_empty() {
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// Zero-copy Reader
let before = buf.filled().len();
match Pin::new(&mut this.upstream).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let after = buf.filled().len();
let bytes_read = after - before;
if bytes_read > 0 {
// Decrypt in-place
let filled = buf.filled_mut();
this.decryptor.apply(&mut filled[before..after]);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
impl<R: AsyncRead + Unpin> CryptoReader<R> {
/// Read and decrypt exactly n bytes with Async
pub async fn read_exact_decrypt(&mut self, n: usize) -> Result<Bytes> {
let mut result = BytesMut::with_capacity(n);
if !self.buffer.is_empty() {
let to_take = self.buffer.len().min(n);
result.extend_from_slice(&self.buffer.split_to(to_take));
}
// Reread
while result.len() < n {
let mut temp = vec![0u8; n - result.len()];
let read = self.upstream.read(&mut temp).await?;
if read == 0 {
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
}
// Decrypt
self.decryptor.apply(&mut temp[..read]);
result.extend_from_slice(&temp[..read]);
}
Ok(result.freeze())
}
}
/// Writer that encrypts data using AES-CTR
pub struct CryptoWriter<W> {
upstream: W,
encryptor: AesCtr,
pending: BytesMut,
}
impl<W> CryptoWriter<W> {
pub fn new(upstream: W, encryptor: AesCtr) -> Self {
Self {
upstream,
encryptor,
pending: BytesMut::with_capacity(8192),
}
}
pub fn get_ref(&self) -> &W {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
pub fn into_inner(self) -> W {
self.upstream
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
let this = self.get_mut();
if !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
if !this.pending.is_empty() {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
// Pending Null
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
// Encrypt
let mut encrypted = buf.to_vec();
this.encryptor.apply(&mut encrypted);
// Write Try
match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) {
Poll::Ready(Ok(written)) => {
if written < encrypted.len() {
// Partial write — сохраняем остаток в pending
this.pending.extend_from_slice(&encrypted[written..]);
}
Poll::Ready(Ok(buf.len()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
this.pending.extend_from_slice(&encrypted);
Poll::Ready(Ok(buf.len()))
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut();
while !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(Error::new(
ErrorKind::WriteZero,
"Failed to write pending data during flush",
)));
}
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.upstream).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut();
while !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(0)) => {
break;
}
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
}
Poll::Ready(Err(_)) => {
break;
}
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.upstream).poll_shutdown(cx)
}
}
/// Passthrough stream for fast mode - no encryption/decryption
pub struct PassthroughStream<S> {
inner: S,
}
impl<S> PassthroughStream<S> {
pub fn new(inner: S) -> Self {
Self { inner }
}
pub fn into_inner(self) -> S {
self.inner
}
}
impl<S: AsyncRead + Unpin> AsyncRead for PassthroughStream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable};
use tokio::io::duplex;
/// Mock writer
struct PartialWriter {
chunk_size: usize,
data: Vec<u8>,
write_count: usize,
}
impl PartialWriter {
fn new(chunk_size: usize) -> Self {
Self {
chunk_size,
data: Vec::new(),
write_count: 0,
}
}
}
impl AsyncWrite for PartialWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
self.write_count += 1;
let to_write = buf.len().min(self.chunk_size);
self.data.extend_from_slice(&buf[..to_write]);
Poll::Ready(Ok(to_write))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
}
fn noop_waker() -> Waker {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| RawWaker::new(std::ptr::null(), &VTABLE),
|_| {},
|_| {},
|_| {},
);
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
#[test]
fn test_crypto_writer_partial_write_correctness() {
let key = [0x42u8; 32];
let iv = 12345u128;
// 10-byte Writer
let mock_writer = PartialWriter::new(10);
let encryptor = AesCtr::new(&key, iv);
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
// 25 byte
let original = b"Hello, this is test data!";
// First Write
let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original);
assert!(matches!(result, Poll::Ready(Ok(25))));
// Flush before continue Pending
loop {
match Pin::new(&mut crypto_writer).poll_flush(&mut cx) {
Poll::Ready(Ok(())) => break,
Poll::Ready(Err(e)) => panic!("Flush error: {}", e),
Poll::Pending => continue,
}
}
// Write Check
let encrypted = &crypto_writer.upstream.data;
assert_eq!(encrypted.len(), 25);
// Decrypt + Verify
let mut decryptor = AesCtr::new(&key, iv);
let mut decrypted = encrypted.clone();
decryptor.apply(&mut decrypted);
assert_eq!(&decrypted, original);
}
#[test]
fn test_crypto_writer_multiple_partial_writes() {
let key = [0xAB; 32];
let iv = 9999u128;
let mock_writer = PartialWriter::new(3);
let encryptor = AesCtr::new(&key, iv);
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let data1 = b"First";
let data2 = b"Second";
let data3 = b"Third";
Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap();
// Flush
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap();
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap();
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
// Assemble
let mut expected = Vec::new();
expected.extend_from_slice(data1);
expected.extend_from_slice(data2);
expected.extend_from_slice(data3);
// Decrypt
let mut decryptor = AesCtr::new(&key, iv);
let mut decrypted = crypto_writer.upstream.data.clone();
decryptor.apply(&mut decrypted);
assert_eq!(decrypted, expected);
}
#[tokio::test]
async fn test_crypto_stream_roundtrip() {
let key = [0u8; 32];
let iv = 12345u128;
let (client, server) = duplex(4096);
let encryptor = AesCtr::new(&key, iv);
let decryptor = AesCtr::new(&key, iv);
let mut writer = CryptoWriter::new(client, encryptor);
let mut reader = CryptoReader::new(server, decryptor);
// Write
let original = b"Hello, encrypted world!";
writer.write_all(original).await.unwrap();
writer.flush().await.unwrap();
// Read
let mut buf = vec![0u8; original.len()];
reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, original);
}
#[tokio::test]
async fn test_crypto_stream_large_data() {
let key = [0x55u8; 32];
let iv = 777u128;
let (client, server) = duplex(1024);
let encryptor = AesCtr::new(&key, iv);
let decryptor = AesCtr::new(&key, iv);
let mut writer = CryptoWriter::new(client, encryptor);
let mut reader = CryptoReader::new(server, decryptor);
// Hugeload
let original: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
// Write
let write_data = original.clone();
let write_handle = tokio::spawn(async move {
writer.write_all(&write_data).await.unwrap();
writer.flush().await.unwrap();
writer.shutdown().await.unwrap();
});
// Read
let mut received = Vec::new();
let mut buf = vec![0u8; 1024];
loop {
match reader.read(&mut buf).await {
Ok(0) => break,
Ok(n) => received.extend_from_slice(&buf[..n]),
Err(e) => panic!("Read error: {}", e),
}
}
write_handle.await.unwrap();
assert_eq!(received, original);
}
}

585
src/stream/frame_stream.rs Normal file
View File

@@ -0,0 +1,585 @@
//! MTProto frame stream wrappers
use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*;
use crate::crypto::crc32;
use crate::crypto::random::SECURE_RANDOM;
use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame =============
/// Reader for abridged MTProto framing
pub struct AbridgedFrameReader<R> {
upstream: R,
}
impl<R> AbridgedFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
/// Read a frame and return (data, metadata)
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read length byte
let mut len_byte = [0u8];
self.upstream.read_exact(&mut len_byte).await?;
let mut len = len_byte[0] as usize;
// Check QuickACK flag (high bit)
if len >= 0x80 {
meta.quickack = true;
len -= 0x80;
}
// Extended length (3 bytes)
if len == 0x7f {
let mut len_bytes = [0u8; 3];
self.upstream.read_exact(&mut len_bytes).await?;
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
}
// Length is in 4-byte words
let byte_len = len * 4;
// Read data
let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for abridged MTProto framing
pub struct AbridgedFrameWriter<W> {
upstream: W,
}
impl<W> AbridgedFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
/// Write a frame
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if data.len() % 4 != 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
));
}
// Simple ACK: send reversed data
if meta.simple_ack {
let reversed: Vec<u8> = data.iter().rev().copied().collect();
self.upstream.write_all(&reversed).await?;
return Ok(());
}
let len_div_4 = data.len() / 4;
if len_div_4 < 0x7f {
// Short length (1 byte)
self.upstream.write_all(&[len_div_4 as u8]).await?;
} else if len_div_4 < (1 << 24) {
// Long length (4 bytes: 0x7f + 3 bytes)
let mut header = [0x7f, 0, 0, 0];
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
self.upstream.write_all(&header).await?;
} else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {} bytes", data.len()),
));
}
self.upstream.write_all(data).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Intermediate Frame =============
/// Reader for intermediate MTProto framing
pub struct IntermediateFrameReader<R> {
upstream: R,
}
impl<R> IntermediateFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag (high bit)
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for intermediate MTProto framing
pub struct IntermediateFrameWriter<W> {
upstream: W,
}
impl<W> IntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
} else {
let len_bytes = (data.len() as u32).to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Secure Intermediate Frame =============
/// Reader for secure intermediate MTProto framing (with padding)
pub struct SecureIntermediateFrameReader<R> {
upstream: R,
}
impl<R> SecureIntermediateFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
// Strip padding (not aligned to 4)
if len % 4 != 0 {
let actual_len = len - (len % 4);
data.truncate(actual_len);
}
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> {
upstream: W,
}
impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
return Ok(());
}
// Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4);
let padding = SECURE_RANDOM.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
self.upstream.write_all(&padding).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Full MTProto Frame (with CRC) =============
/// Reader for full MTProto framing with sequence numbers and CRC32
pub struct MtprotoFrameReader<R> {
upstream: R,
seq_no: i32,
}
impl<R> MtprotoFrameReader<R> {
pub fn new(upstream: R, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
}
}
impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<Bytes> {
loop {
// Read length (4 bytes)
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let len = u32::from_le_bytes(len_bytes) as usize;
// Skip padding-only messages
if len == 4 {
continue;
}
// Validate length
if len < MIN_MSG_LEN || len > MAX_MSG_LEN || len % PADDING_FILLER.len() != 0 {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Invalid message length: {}", len),
));
}
// Read sequence number
let mut seq_bytes = [0u8; 4];
self.upstream.read_exact(&mut seq_bytes).await?;
let msg_seq = i32::from_le_bytes(seq_bytes);
if msg_seq != self.seq_no {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
));
}
self.seq_no += 1;
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
let data_len = len - 12;
let mut data = vec![0u8; data_len];
self.upstream.read_exact(&mut data).await?;
// Read and verify CRC32
let mut crc_bytes = [0u8; 4];
self.upstream.read_exact(&mut crc_bytes).await?;
let expected_crc = u32::from_le_bytes(crc_bytes);
// Compute CRC over len + seq + data
let mut crc_input = Vec::with_capacity(8 + data_len);
crc_input.extend_from_slice(&len_bytes);
crc_input.extend_from_slice(&seq_bytes);
crc_input.extend_from_slice(&data);
let computed_crc = crc32(&crc_input);
if computed_crc != expected_crc {
return Err(Error::new(
ErrorKind::InvalidData,
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
));
}
return Ok(Bytes::from(data));
}
}
}
/// Writer for full MTProto framing
pub struct MtprotoFrameWriter<W> {
upstream: W,
seq_no: i32,
}
impl<W> MtprotoFrameWriter<W> {
pub fn new(upstream: W, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
}
}
impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
let len = msg.len() + 12;
let len_bytes = (len as u32).to_le_bytes();
let seq_bytes = self.seq_no.to_le_bytes();
self.seq_no += 1;
// Compute CRC
let mut crc_input = Vec::with_capacity(8 + msg.len());
crc_input.extend_from_slice(&len_bytes);
crc_input.extend_from_slice(&seq_bytes);
crc_input.extend_from_slice(msg);
let checksum = crc32(&crc_input);
let crc_bytes = checksum.to_le_bytes();
// Calculate padding for CBC alignment
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
let padding_count = padding_needed / PADDING_FILLER.len();
// Write everything
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(&seq_bytes).await?;
self.upstream.write_all(msg).await?;
self.upstream.write_all(&crc_bytes).await?;
for _ in 0..padding_count {
self.upstream.write_all(&PADDING_FILLER).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
// ============= Frame Type Enum =============
/// Enum for different frame stream types
pub enum FrameReaderKind<R> {
Abridged(AbridgedFrameReader<R>),
Intermediate(IntermediateFrameReader<R>),
SecureIntermediate(SecureIntermediateFrameReader<R>),
}
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
}
}
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
match self {
FrameReaderKind::Abridged(r) => r.read_frame().await,
FrameReaderKind::Intermediate(r) => r.read_frame().await,
FrameReaderKind::SecureIntermediate(r) => r.read_frame().await,
}
}
}
pub enum FrameWriterKind<W> {
Abridged(AbridgedFrameWriter<W>),
Intermediate(IntermediateFrameWriter<W>),
SecureIntermediate(SecureIntermediateFrameWriter<W>),
}
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
}
}
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
FrameWriterKind::Intermediate(w) => w.write_frame(data, meta).await,
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
}
}
pub async fn flush(&mut self) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.flush().await,
FrameWriterKind::Intermediate(w) => w.flush().await,
FrameWriterKind::SecureIntermediate(w) => w.flush().await,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[tokio::test]
async fn test_abridged_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Short frame
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_abridged_long_frame() {
let (client, server) = duplex(65536);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Long frame (> 0x7f words = 508 bytes)
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
let padded_len = (data.len() + 3) / 4 * 4;
let mut padded = data.clone();
padded.resize(padded_len, 0);
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &padded[..]);
}
#[tokio::test]
async fn test_intermediate_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = IntermediateFrameWriter::new(client);
let mut reader = IntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client);
let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
// Received should have padding stripped to align to 4
let expected_len = (data.len() / 4) * 4;
assert_eq!(received.len(), expected_len);
}
#[tokio::test]
async fn test_mtproto_frame_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = MtprotoFrameWriter::new(client, 0);
let mut reader = MtprotoFrameReader::new(server, 0);
// Message must be padded properly
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
writer.write_frame(&data).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_frame_reader_kind() {
let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
}

10
src/stream/mod.rs Normal file
View File

@@ -0,0 +1,10 @@
//! Stream wrappers for MTProto protocol layers
pub mod traits;
pub mod crypto_stream;
pub mod tls_stream;
pub mod frame_stream;
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
pub use frame_stream::*;

277
src/stream/tls_stream.rs Normal file
View File

@@ -0,0 +1,277 @@
//! Fake TLS 1.3 stream wrappers
use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::protocol::constants::{
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
MAX_TLS_CHUNK_SIZE,
};
use parking_lot::Mutex;
/// Reader that unwraps TLS 1.3 records
pub struct FakeTlsReader<R> {
upstream: R,
buffer: BytesMut,
pending_read: Option<PendingTlsRead>,
}
struct PendingTlsRead {
record_type: u8,
remaining: usize,
}
impl<R> FakeTlsReader<R> {
/// Create new fake TLS reader
pub fn new(upstream: R) -> Self {
Self {
upstream,
buffer: BytesMut::with_capacity(16384),
pending_read: None,
}
}
/// Get reference to upstream
pub fn get_ref(&self) -> &R {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
}
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
/// Read exactly n bytes through TLS layer
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
while self.buffer.len() < n {
let data = self.read_tls_record().await?;
if data.is_empty() {
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
}
self.buffer.extend_from_slice(&data);
}
Ok(self.buffer.split_to(n).freeze())
}
/// Read a single TLS record
async fn read_tls_record(&mut self) -> Result<Vec<u8>> {
loop {
// Read TLS record header (5 bytes)
let mut header = [0u8; 5];
self.upstream.read_exact(&mut header).await?;
let record_type = header[0];
let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
// Validate version
if version != TLS_VERSION {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Invalid TLS version: {:02x?}", version),
));
}
// Read record body
let mut data = vec![0u8; length];
self.upstream.read_exact(&mut data).await?;
match record_type {
TLS_RECORD_CHANGE_CIPHER => continue, // Skip
TLS_RECORD_APPLICATION => return Ok(data),
_ => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Unexpected TLS record type: 0x{:02x}", record_type),
));
}
}
}
}
}
impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
// Drain buffer first
if !self.buffer.is_empty() {
let to_copy = self.buffer.len().min(buf.remaining());
buf.put_slice(&self.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// We need to read a TLS record, but poll_read doesn't support async/await
// So we'll do a simplified version that reads header synchronously
// Read header
let mut header = [0u8; 5];
let mut header_buf = ReadBuf::new(&mut header);
match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) {
Poll::Ready(Ok(())) => {
if header_buf.filled().len() < 5 {
// Need more data - store what we have and return pending
// For simplicity, we'll just return empty
return Poll::Ready(Ok(()));
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let record_type = header[0];
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
if record_type == TLS_RECORD_CHANGE_CIPHER {
// Skip this record, try again
cx.waker().wake_by_ref();
return Poll::Pending;
}
if record_type != TLS_RECORD_APPLICATION {
return Poll::Ready(Err(Error::new(
ErrorKind::InvalidData,
"Invalid TLS record type",
)));
}
// Read body
let mut body = vec![0u8; length];
let mut body_buf = ReadBuf::new(&mut body);
match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) {
Poll::Ready(Ok(())) => {
let filled = body_buf.filled();
let to_copy = filled.len().min(buf.remaining());
buf.put_slice(&filled[..to_copy]);
if filled.len() > to_copy {
self.buffer.extend_from_slice(&filled[to_copy..]);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
/// Writer that wraps data in TLS 1.3 records
pub struct FakeTlsWriter<W> {
upstream: W,
}
impl<W> FakeTlsWriter<W> {
/// Create new fake TLS writer
pub fn new(upstream: W) -> Self {
Self { upstream }
}
/// Get reference to upstream
pub fn get_ref(&self) -> &W {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> W {
self.upstream
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
// Build TLS record
let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE);
let chunk = &buf[..chunk_size];
let mut record = Vec::with_capacity(5 + chunk_size);
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.push((chunk_size >> 8) as u8);
record.push(chunk_size as u8);
record.extend_from_slice(chunk);
match Pin::new(&mut self.upstream).poll_write(cx, &record) {
Poll::Ready(Ok(written)) => {
if written >= 5 {
Poll::Ready(Ok(written - 5))
} else {
Poll::Ready(Ok(0))
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.upstream).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.upstream).poll_shutdown(cx)
}
}
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Write all data wrapped in TLS records (async method)
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) {
let header = [
TLS_RECORD_APPLICATION,
TLS_VERSION[0],
TLS_VERSION[1],
(chunk.len() >> 8) as u8,
chunk.len() as u8,
];
self.upstream.write_all(&header).await?;
self.upstream.write_all(chunk).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[tokio::test]
async fn test_tls_stream_roundtrip() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original = b"Hello, fake TLS!";
writer.write_all_tls(original).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_exact(original.len()).await.unwrap();
assert_eq!(&received[..], original);
}
}

113
src/stream/traits.rs Normal file
View File

@@ -0,0 +1,113 @@
//! Stream traits and common types
use bytes::Bytes;
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Extra metadata for frames
#[derive(Debug, Clone, Default)]
pub struct FrameMeta {
/// Quick ACK requested
pub quickack: bool,
/// This is a simple ACK message
pub simple_ack: bool,
/// Skip sending this frame
pub skip_send: bool,
}
impl FrameMeta {
pub fn new() -> Self {
Self::default()
}
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
}
}
/// Result of reading a frame
#[derive(Debug)]
pub enum ReadFrameResult {
/// Frame data with metadata
Frame(Bytes, FrameMeta),
/// Connection closed
Closed,
}
/// Trait for streams that wrap another stream
pub trait LayeredStream<U> {
/// Get reference to upstream
fn upstream(&self) -> &U;
/// Get mutable reference to upstream
fn upstream_mut(&mut self) -> &mut U;
/// Consume self and return upstream
fn into_upstream(self) -> U;
}
/// A split read half of a stream
pub struct ReadHalf<R> {
inner: R,
}
impl<R> ReadHalf<R> {
pub fn new(inner: R) -> Self {
Self { inner }
}
pub fn into_inner(self) -> R {
self.inner
}
}
impl<R: AsyncRead + Unpin> AsyncRead for ReadHalf<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
/// A split write half of a stream
pub struct WriteHalf<W> {
inner: W,
}
impl<W> WriteHalf<W> {
pub fn new(inner: W) -> Self {
Self { inner }
}
pub fn into_inner(self) -> W {
self.inner
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

9
src/transport/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
//! Transport layer: connection pooling, socket utilities, proxy protocol
pub mod pool;
pub mod proxy_protocol;
pub mod socket;
pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*;

338
src/transport/pool.rs Normal file
View File

@@ -0,0 +1,338 @@
//! Connection Pool
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;
use parking_lot::RwLock;
use tracing::{debug, warn};
use crate::error::{ProxyError, Result};
use super::socket::configure_tcp_socket;
/// A pooled connection with metadata
struct PooledConnection {
stream: TcpStream,
created_at: Instant,
}
/// Internal pool state for a single endpoint
struct PoolInner {
/// Available connections
connections: Vec<PooledConnection>,
/// Number of connections being established
pending: usize,
}
impl PoolInner {
fn new() -> Self {
Self {
connections: Vec::new(),
pending: 0,
}
}
}
/// Connection pool configuration
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Maximum connections per endpoint
pub max_connections: usize,
/// Connection timeout
pub connect_timeout: Duration,
/// Maximum idle time before connection is dropped
pub max_idle_time: Duration,
/// Enable TCP keepalive
pub keepalive: bool,
/// Keepalive interval
pub keepalive_interval: Duration,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_connections: 64,
connect_timeout: Duration::from_secs(10),
max_idle_time: Duration::from_secs(60),
keepalive: true,
keepalive_interval: Duration::from_secs(40),
}
}
}
/// Thread-safe connection pool
pub struct ConnectionPool {
/// Per-endpoint pools
pools: RwLock<HashMap<SocketAddr, Arc<Mutex<PoolInner>>>>,
/// Configuration
config: PoolConfig,
}
impl ConnectionPool {
/// Create new connection pool with default config
pub fn new() -> Self {
Self::with_config(PoolConfig::default())
}
/// Create connection pool with custom config
pub fn with_config(config: PoolConfig) -> Self {
Self {
pools: RwLock::new(HashMap::new()),
config,
}
}
/// Get or create pool for an endpoint
fn get_or_create_pool(&self, addr: SocketAddr) -> Arc<Mutex<PoolInner>> {
// Fast path with read lock
{
let pools = self.pools.read();
if let Some(pool) = pools.get(&addr) {
return Arc::clone(pool);
}
}
// Slow path with write lock
let mut pools = self.pools.write();
pools.entry(addr)
.or_insert_with(|| Arc::new(Mutex::new(PoolInner::new())))
.clone()
}
/// Get a connection to the specified address
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
let pool = self.get_or_create_pool(addr);
// Try to get an existing connection
{
let mut inner = pool.lock().await;
// Remove stale connections
let now = Instant::now();
inner.connections.retain(|c| {
now.duration_since(c.created_at) < self.config.max_idle_time
});
// Try to find a usable connection
while let Some(conn) = inner.connections.pop() {
// Check if connection is still alive
if is_connection_alive(&conn.stream) {
debug!(addr = %addr, "Reusing pooled connection");
return Ok(conn.stream);
}
debug!(addr = %addr, "Discarding dead pooled connection");
}
// Check if we can create a new connection
let total = inner.connections.len() + inner.pending;
if total >= self.config.max_connections {
return Err(ProxyError::ConnectionTimeout {
addr: addr.to_string()
});
}
inner.pending += 1;
}
// Create new connection
debug!(addr = %addr, "Creating new connection");
let result = self.create_connection(addr).await;
// Decrement pending count
{
let mut inner = pool.lock().await;
inner.pending = inner.pending.saturating_sub(1);
}
result
}
/// Create a new connection to the address
async fn create_connection(&self, addr: SocketAddr) -> Result<TcpStream> {
let connect_future = TcpStream::connect(addr);
let stream = timeout(self.config.connect_timeout, connect_future)
.await
.map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string()
})?
.map_err(|e| {
if e.kind() == std::io::ErrorKind::ConnectionRefused {
ProxyError::ConnectionRefused { addr: addr.to_string() }
} else {
ProxyError::Io(e)
}
})?;
// Configure socket
configure_tcp_socket(
&stream,
self.config.keepalive,
self.config.keepalive_interval,
)?;
Ok(stream)
}
/// Return a connection to the pool
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
let pool = self.get_or_create_pool(addr);
let mut inner = pool.lock().await;
if inner.connections.len() < self.config.max_connections {
inner.connections.push(PooledConnection {
stream,
created_at: Instant::now(),
});
debug!(addr = %addr, pool_size = inner.connections.len(), "Returned connection to pool");
} else {
debug!(addr = %addr, "Pool full, dropping connection");
}
}
/// Close all pooled connections
pub async fn close_all(&self) {
let pools = self.pools.read();
for (addr, pool) in pools.iter() {
let mut inner = pool.lock().await;
let count = inner.connections.len();
inner.connections.clear();
debug!(addr = %addr, count = count, "Closed pooled connections");
}
}
/// Get pool statistics
pub async fn stats(&self) -> PoolStats {
let pools = self.pools.read();
let mut total_connections = 0;
let mut total_pending = 0;
let mut endpoints = 0;
for pool in pools.values() {
let inner = pool.lock().await;
total_connections += inner.connections.len();
total_pending += inner.pending;
endpoints += 1;
}
PoolStats {
endpoints,
total_connections,
total_pending,
}
}
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
/// Pool statistics
#[derive(Debug, Clone)]
pub struct PoolStats {
pub endpoints: usize,
pub total_connections: usize,
pub total_pending: usize,
}
/// Check if a TCP connection is still alive (non-blocking)
fn is_connection_alive(stream: &TcpStream) -> bool {
// Try a non-blocking read to check connection state
let mut buf = [0u8; 1];
match stream.try_read(&mut buf) {
Ok(0) => false, // Connection closed
Ok(_) => true, // Data available (shouldn't happen, but connection is alive)
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => true, // No data, but alive
Err(_) => false, // Some error, assume dead
}
}
/// Connection pool with custom initialization
pub struct InitializingPool<F> {
pool: ConnectionPool,
init_fn: F,
}
impl<F, Fut> InitializingPool<F>
where
F: Fn(TcpStream, SocketAddr) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<TcpStream>> + Send,
{
/// Create pool with initialization function
pub fn new(config: PoolConfig, init_fn: F) -> Self {
Self {
pool: ConnectionPool::with_config(config),
init_fn,
}
}
/// Get an initialized connection
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
let stream = self.pool.get(addr).await?;
(self.init_fn)(stream, addr).await
}
/// Return connection to pool
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
self.pool.put(addr, stream).await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_pool_basic() {
// Start a test server
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Accept connections in background
tokio::spawn(async move {
loop {
let _ = listener.accept().await;
}
});
let pool = ConnectionPool::new();
// Get a connection
let conn1 = pool.get(addr).await.unwrap();
// Return it to pool
pool.put(addr, conn1).await;
// Get again (should reuse)
let _conn2 = pool.get(addr).await.unwrap();
let stats = pool.stats().await;
assert_eq!(stats.endpoints, 1);
}
#[tokio::test]
async fn test_pool_connection_refused() {
let pool = ConnectionPool::with_config(PoolConfig {
connect_timeout: Duration::from_millis(100),
..Default::default()
});
// Try to connect to a port that's not listening
let result = pool.get("127.0.0.1:1".parse().unwrap()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_pool_stats() {
let pool = ConnectionPool::new();
let stats = pool.stats().await;
assert_eq!(stats.endpoints, 0);
assert_eq!(stats.total_connections, 0);
}
}

View File

@@ -0,0 +1,381 @@
//! HAProxy PROXY protocol V1/V2
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::error::{ProxyError, Result};
/// PROXY protocol v1 signature
const PROXY_V1_SIGNATURE: &[u8] = b"PROXY ";
/// PROXY protocol v2 signature
const PROXY_V2_SIGNATURE: &[u8] = &[
0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a,
0x51, 0x55, 0x49, 0x54, 0x0a
];
/// Minimum length for v1 detection
const PROXY_V1_MIN_LEN: usize = 6;
/// Minimum length for v2 header
const PROXY_V2_MIN_LEN: usize = 16;
/// Address families for v2
mod address_family {
pub const UNSPEC: u8 = 0x0;
pub const INET: u8 = 0x1;
pub const INET6: u8 = 0x2;
}
/// Information extracted from PROXY protocol header
#[derive(Debug, Clone)]
pub struct ProxyProtocolInfo {
/// Source (client) address
pub src_addr: SocketAddr,
/// Destination address (optional)
pub dst_addr: Option<SocketAddr>,
/// Protocol version used (1 or 2)
pub version: u8,
}
impl ProxyProtocolInfo {
/// Create info with just source address
pub fn new(src_addr: SocketAddr) -> Self {
Self {
src_addr,
dst_addr: None,
version: 0,
}
}
}
/// Parse PROXY protocol header from a stream
///
/// Returns the parsed info or an error if the header is invalid.
/// The stream position is advanced past the header.
pub async fn parse_proxy_protocol<R: AsyncRead + Unpin>(
reader: &mut R,
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
// Read enough bytes to detect version
let mut header = [0u8; PROXY_V2_MIN_LEN];
reader.read_exact(&mut header[..PROXY_V1_MIN_LEN]).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
// Check for v1
if header[..PROXY_V1_MIN_LEN] == PROXY_V1_SIGNATURE[..] {
return parse_v1(reader, default_peer).await;
}
// Read rest for v2 detection
reader.read_exact(&mut header[PROXY_V1_MIN_LEN..]).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
// Check for v2
if header[..12] == PROXY_V2_SIGNATURE[..] {
return parse_v2(reader, &header, default_peer).await;
}
Err(ProxyError::InvalidProxyProtocol)
}
/// Parse PROXY protocol v1
async fn parse_v1<R: AsyncRead + Unpin>(
reader: &mut R,
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
// Read until CRLF (max 107 bytes total for v1)
let mut line = Vec::with_capacity(128);
line.extend_from_slice(PROXY_V1_SIGNATURE);
loop {
let mut byte = [0u8];
reader.read_exact(&mut byte).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
line.push(byte[0]);
if line.ends_with(b"\r\n") {
break;
}
if line.len() > 256 {
return Err(ProxyError::InvalidProxyProtocol);
}
}
// Parse the line: PROXY TCP4/TCP6/UNKNOWN src_ip dst_ip src_port dst_port
let line_str = std::str::from_utf8(&line[PROXY_V1_MIN_LEN..line.len() - 2])
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let parts: Vec<&str> = line_str.split_whitespace().collect();
if parts.is_empty() {
return Err(ProxyError::InvalidProxyProtocol);
}
match parts[0] {
"TCP4" | "TCP6" if parts.len() >= 5 => {
let src_ip: IpAddr = parts[1].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let dst_ip: IpAddr = parts[2].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let src_port: u16 = parts[3].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let dst_port: u16 = parts[4].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(src_ip, src_port),
dst_addr: Some(SocketAddr::new(dst_ip, dst_port)),
version: 1,
})
}
"UNKNOWN" => {
// UNKNOWN means no address info, use default
Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 1,
})
}
_ => Err(ProxyError::InvalidProxyProtocol),
}
}
/// Parse PROXY protocol v2
async fn parse_v2<R: AsyncRead + Unpin>(
reader: &mut R,
header: &[u8; PROXY_V2_MIN_LEN],
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
let version_command = header[12];
let version = version_command >> 4;
let command = version_command & 0x0f;
// Must be version 2
if version != 2 {
return Err(ProxyError::InvalidProxyProtocol);
}
let family_protocol = header[13];
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
// Read address data
let mut addr_data = vec![0u8; addr_len];
if addr_len > 0 {
reader.read_exact(&mut addr_data).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
}
// LOCAL command (0x0) - use default peer
if command == 0 {
return Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 2,
});
}
// PROXY command (0x1) - parse addresses
if command != 1 {
return Err(ProxyError::InvalidProxyProtocol);
}
let family = family_protocol >> 4;
match family {
address_family::INET if addr_len >= 12 => {
// IPv4: 4 + 4 + 2 + 2 = 12 bytes
let src_ip = Ipv4Addr::new(
addr_data[0], addr_data[1],
addr_data[2], addr_data[3]
);
let dst_ip = Ipv4Addr::new(
addr_data[4], addr_data[5],
addr_data[6], addr_data[7]
);
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)),
version: 2,
})
}
address_family::INET6 if addr_len >= 36 => {
// IPv6: 16 + 16 + 2 + 2 = 36 bytes
let src_ip = Ipv6Addr::from(
<[u8; 16]>::try_from(&addr_data[0..16]).unwrap()
);
let dst_ip = Ipv6Addr::from(
<[u8; 16]>::try_from(&addr_data[16..32]).unwrap()
);
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)),
version: 2,
})
}
address_family::UNSPEC => {
Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 2,
})
}
_ => Err(ProxyError::InvalidProxyProtocol),
}
}
/// Builder for PROXY protocol v1 header
pub struct ProxyProtocolV1Builder {
family: &'static str,
src_addr: Option<SocketAddr>,
dst_addr: Option<SocketAddr>,
}
impl ProxyProtocolV1Builder {
pub fn new() -> Self {
Self {
family: "UNKNOWN",
src_addr: None,
dst_addr: None,
}
}
pub fn tcp4(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
self.family = "TCP4";
self.src_addr = Some(src);
self.dst_addr = Some(dst);
self
}
pub fn tcp6(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
self.family = "TCP6";
self.src_addr = Some(src);
self.dst_addr = Some(dst);
self
}
pub fn build(&self) -> Vec<u8> {
match (self.src_addr, self.dst_addr) {
(Some(src), Some(dst)) => {
format!(
"PROXY {} {} {} {} {}\r\n",
self.family,
src.ip(),
dst.ip(),
src.port(),
dst.port()
).into_bytes()
}
_ => b"PROXY UNKNOWN\r\n".to_vec(),
}
}
}
impl Default for ProxyProtocolV1Builder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[tokio::test]
async fn test_parse_v1_tcp4() {
let header = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
let default = "0.0.0.0:0".parse().unwrap();
// Simulate that we've already read the signature
let info = parse_v1(&mut cursor, default).await.unwrap();
assert_eq!(info.version, 1);
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
assert_eq!(info.src_addr.port(), 12345);
assert!(info.dst_addr.is_some());
}
#[tokio::test]
async fn test_parse_v1_unknown() {
let header = b"PROXY UNKNOWN\r\n";
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
let info = parse_v1(&mut cursor, default).await.unwrap();
assert_eq!(info.version, 1);
assert_eq!(info.src_addr, default);
}
#[tokio::test]
async fn test_parse_v2_tcp4() {
// v2 header for TCP4
let mut header = [0u8; 16];
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
header[12] = 0x21; // v2, PROXY command
header[13] = 0x11; // AF_INET, STREAM
header[14] = 0x00;
header[15] = 0x0c; // 12 bytes of address data
let addr_data = [
192, 168, 1, 1, // src IP
10, 0, 0, 1, // dst IP
0x30, 0x39, // src port (12345)
0x01, 0xbb, // dst port (443)
];
let mut cursor = Cursor::new(addr_data.to_vec());
let default = "0.0.0.0:0".parse().unwrap();
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
assert_eq!(info.version, 2);
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
assert_eq!(info.src_addr.port(), 12345);
}
#[tokio::test]
async fn test_parse_v2_local() {
let mut header = [0u8; 16];
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
header[12] = 0x20; // v2, LOCAL command
header[13] = 0x00;
header[14] = 0x00;
header[15] = 0x00; // 0 bytes of address data
let mut cursor = Cursor::new(Vec::new());
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
assert_eq!(info.version, 2);
assert_eq!(info.src_addr, default);
}
#[test]
fn test_v1_builder() {
let src: SocketAddr = "192.168.1.1:12345".parse().unwrap();
let dst: SocketAddr = "10.0.0.1:443".parse().unwrap();
let header = ProxyProtocolV1Builder::new()
.tcp4(src, dst)
.build();
let expected = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
assert_eq!(header, expected);
}
#[test]
fn test_v1_builder_unknown() {
let header = ProxyProtocolV1Builder::new().build();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
}

230
src/transport/socket.rs Normal file
View File

@@ -0,0 +1,230 @@
//! TCP Socket Configuration
use std::io::Result;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpStream;
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
use tracing::debug;
/// Configure TCP socket with recommended settings for proxy use
pub fn configure_tcp_socket(
stream: &TcpStream,
keepalive: bool,
keepalive_interval: Duration,
) -> Result<()> {
let socket = socket2::SockRef::from(stream);
// Disable Nagle's algorithm for lower latency
socket.set_nodelay(true)?;
// Set keepalive if enabled
if keepalive {
let keepalive = TcpKeepalive::new()
.with_time(keepalive_interval);
// Platform-specific keepalive settings
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
let keepalive = keepalive.with_interval(keepalive_interval);
socket.set_tcp_keepalive(&keepalive)?;
}
// Set buffer sizes
set_buffer_sizes(&socket, 65536, 65536)?;
Ok(())
}
/// Set socket buffer sizes
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
// These may fail on some systems, so we ignore errors
let _ = socket.set_recv_buffer_size(recv);
let _ = socket.set_send_buffer_size(send);
Ok(())
}
/// Configure socket for accepting client connections
pub fn configure_client_socket(
stream: &TcpStream,
keepalive_secs: u64,
ack_timeout_secs: u64,
) -> Result<()> {
let socket = socket2::SockRef::from(stream);
// Disable Nagle's algorithm
socket.set_nodelay(true)?;
// Set keepalive
let keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(keepalive_secs));
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
let keepalive = keepalive.with_interval(Duration::from_secs(keepalive_secs));
socket.set_tcp_keepalive(&keepalive)?;
// Set TCP user timeout (Linux only)
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
);
}
}
Ok(())
}
/// Set socket to send RST on close (for masking)
pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
let socket = socket2::SockRef::from(stream);
socket.set_linger(Some(Duration::ZERO))?;
Ok(())
}
/// Create a new TCP socket for outgoing connections
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
// Set non-blocking
socket.set_nonblocking(true)?;
// Disable Nagle
socket.set_nodelay(true)?;
Ok(socket)
}
/// Get local address of a socket
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
stream.local_addr().ok()
}
/// Get peer address of a socket
pub fn get_peer_addr(stream: &TcpStream) -> Option<SocketAddr> {
stream.peer_addr().ok()
}
/// Check if address is IPv6
pub fn is_ipv6(addr: &SocketAddr) -> bool {
addr.is_ipv6()
}
/// Parse IPv4-mapped IPv6 address to IPv4
pub fn normalize_ip(addr: SocketAddr) -> SocketAddr {
match addr {
SocketAddr::V6(v6) => {
if let Some(v4) = v6.ip().to_ipv4_mapped() {
SocketAddr::new(std::net::IpAddr::V4(v4), v6.port())
} else {
addr
}
}
_ => addr,
}
}
/// Socket options for server listening
#[derive(Debug, Clone)]
pub struct ListenOptions {
/// Enable SO_REUSEADDR
pub reuse_addr: bool,
/// Enable SO_REUSEPORT (Linux/BSD)
pub reuse_port: bool,
/// Backlog size
pub backlog: u32,
/// IPv6 only (disable dual-stack)
pub ipv6_only: bool,
}
impl Default for ListenOptions {
fn default() -> Self {
Self {
reuse_addr: true,
reuse_port: true,
backlog: 1024,
ipv6_only: false,
}
}
}
/// Create a listening socket with the specified options
pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Socket> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if options.reuse_addr {
socket.set_reuse_address(true)?;
}
#[cfg(unix)]
if options.reuse_port {
socket.set_reuse_port(true)?;
}
if addr.is_ipv6() && options.ipv6_only {
socket.set_only_v6(true)?;
}
socket.set_nonblocking(true)?;
socket.bind(&addr.into())?;
socket.listen(options.backlog as i32)?;
debug!(addr = %addr, "Created listening socket");
Ok(socket)
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_configure_socket() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let stream = TcpStream::connect(addr).await.unwrap();
configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap();
}
#[test]
fn test_normalize_ip() {
// IPv4 stays IPv4
let v4: SocketAddr = "192.168.1.1:8080".parse().unwrap();
assert_eq!(normalize_ip(v4), v4);
// Pure IPv6 stays IPv6
let v6: SocketAddr = "[::1]:8080".parse().unwrap();
assert_eq!(normalize_ip(v6), v6);
}
#[test]
fn test_listen_options_default() {
let opts = ListenOptions::default();
assert!(opts.reuse_addr);
assert!(opts.reuse_port);
assert_eq!(opts.backlog, 1024);
}
}

118
src/util/ip.rs Normal file
View File

@@ -0,0 +1,118 @@
//! IP Addr Detect
use std::net::IpAddr;
use std::time::Duration;
use tracing::{debug, warn};
/// Detected IP addresses
#[derive(Debug, Clone, Default)]
pub struct IpInfo {
pub ipv4: Option<IpAddr>,
pub ipv6: Option<IpAddr>,
}
impl IpInfo {
/// Check if any IP is detected
pub fn has_any(&self) -> bool {
self.ipv4.is_some() || self.ipv6.is_some()
}
/// Get preferred IP (IPv6 if available and preferred)
pub fn preferred(&self, prefer_ipv6: bool) -> Option<IpAddr> {
if prefer_ipv6 {
self.ipv6.or(self.ipv4)
} else {
self.ipv4.or(self.ipv6)
}
}
}
/// URLs for IP detection
const IPV4_URLS: &[&str] = &[
"http://v4.ident.me/",
"http://ipv4.icanhazip.com/",
"http://api.ipify.org/",
];
const IPV6_URLS: &[&str] = &[
"http://v6.ident.me/",
"http://ipv6.icanhazip.com/",
"http://api6.ipify.org/",
];
/// Detect public IP addresses
pub async fn detect_ip() -> IpInfo {
let mut info = IpInfo::default();
// Detect IPv4
for url in IPV4_URLS {
if let Some(ip) = fetch_ip(url).await {
if ip.is_ipv4() {
info.ipv4 = Some(ip);
debug!(ip = %ip, "Detected IPv4 address");
break;
}
}
}
// Detect IPv6
for url in IPV6_URLS {
if let Some(ip) = fetch_ip(url).await {
if ip.is_ipv6() {
info.ipv6 = Some(ip);
debug!(ip = %ip, "Detected IPv6 address");
break;
}
}
}
if !info.has_any() {
warn!("Failed to detect public IP address");
}
info
}
/// Fetch IP from URL
async fn fetch_ip(url: &str) -> Option<IpAddr> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.build()
.ok()?;
let response = client.get(url).send().await.ok()?;
let text = response.text().await.ok()?;
text.trim().parse().ok()
}
/// Synchronous IP detection (for startup)
pub fn detect_ip_sync() -> IpInfo {
tokio::runtime::Handle::current().block_on(detect_ip())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ip_info() {
let info = IpInfo::default();
assert!(!info.has_any());
let info = IpInfo {
ipv4: Some("1.2.3.4".parse().unwrap()),
ipv6: None,
};
assert!(info.has_any());
assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap()));
assert_eq!(info.preferred(true), Some("1.2.3.4".parse().unwrap()));
let info = IpInfo {
ipv4: Some("1.2.3.4".parse().unwrap()),
ipv6: Some("::1".parse().unwrap()),
};
assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap()));
assert_eq!(info.preferred(true), Some("::1".parse().unwrap()));
}
}

7
src/util/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
//! Utils
pub mod ip;
pub mod time;
pub use ip::*;
pub use time::*;

76
src/util/time.rs Normal file
View File

@@ -0,0 +1,76 @@
//! Time Sync
use std::time::Duration;
use chrono::{DateTime, Utc};
use tracing::{debug, warn, error};
const TIME_SYNC_URL: &str = "https://core.telegram.org/getProxySecret";
const MAX_TIME_SKEW_SECS: i64 = 30;
/// Time sync result
#[derive(Debug, Clone)]
pub struct TimeSyncResult {
pub server_time: DateTime<Utc>,
pub local_time: DateTime<Utc>,
pub skew_secs: i64,
pub is_skewed: bool,
}
/// Check time synchronization with Telegram servers
pub async fn check_time_sync() -> Option<TimeSyncResult> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.ok()?;
let response = client.get(TIME_SYNC_URL).send().await.ok()?;
// Get Date header
let date_header = response.headers().get("date")?;
let date_str = date_header.to_str().ok()?;
// Parse date
let server_time = DateTime::parse_from_rfc2822(date_str)
.ok()?
.with_timezone(&Utc);
let local_time = Utc::now();
let skew_secs = (local_time - server_time).num_seconds();
let is_skewed = skew_secs.abs() > MAX_TIME_SKEW_SECS;
let result = TimeSyncResult {
server_time,
local_time,
skew_secs,
is_skewed,
};
if is_skewed {
warn!(
server = %server_time,
local = %local_time,
skew = skew_secs,
"Time skew detected"
);
} else {
debug!(skew = skew_secs, "Time sync OK");
}
Some(result)
}
/// Background time sync task
pub async fn time_sync_task(check_interval: Duration) -> ! {
loop {
if let Some(result) = check_time_sync().await {
if result.is_skewed {
error!(
"System clock is off by {} seconds. Please sync your clock.",
result.skew_secs
);
}
}
tokio::time::sleep(check_interval).await;
}
}