From 3d9150a0743f85b641bb40d5aea95f61021e576c Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 30 Dec 2025 05:08:05 +0300 Subject: [PATCH] 1.0.0 Tschuss Status Quo - Hallo, Zukunft! --- Cargo.toml | 60 ++++ benches/crypto_bench.rs | 12 + config.toml | 13 + src/config/mod.rs | 227 +++++++++++++ src/crypto/aes.rs | 351 +++++++++++++++++++ src/crypto/hash.rs | 90 +++++ src/crypto/mod.rs | 9 + src/crypto/random.rs | 212 ++++++++++++ src/error.rs | 176 ++++++++++ src/main.rs | 158 +++++++++ src/protocol/constants.rs | 261 ++++++++++++++ src/protocol/frame.rs | 120 +++++++ src/protocol/mod.rs | 11 + src/protocol/obfuscation.rs | 217 ++++++++++++ src/protocol/tls.rs | 244 +++++++++++++ src/proxy/client.rs | 378 +++++++++++++++++++++ src/proxy/handshake.rs | 411 ++++++++++++++++++++++ src/proxy/masking.rs | 115 +++++++ src/proxy/mod.rs | 11 + src/proxy/relay.rs | 162 +++++++++ src/stats/mod.rs | 223 ++++++++++++ src/stream/crypto_stream.rs | 474 ++++++++++++++++++++++++++ src/stream/frame_stream.rs | 585 ++++++++++++++++++++++++++++++++ src/stream/mod.rs | 10 + src/stream/tls_stream.rs | 277 +++++++++++++++ src/stream/traits.rs | 113 ++++++ src/transport/mod.rs | 9 + src/transport/pool.rs | 338 ++++++++++++++++++ src/transport/proxy_protocol.rs | 381 +++++++++++++++++++++ src/transport/socket.rs | 230 +++++++++++++ src/util/ip.rs | 118 +++++++ src/util/mod.rs | 7 + src/util/time.rs | 76 +++++ 33 files changed, 6079 insertions(+) create mode 100644 Cargo.toml create mode 100644 benches/crypto_bench.rs create mode 100644 config.toml create mode 100644 src/config/mod.rs create mode 100644 src/crypto/aes.rs create mode 100644 src/crypto/hash.rs create mode 100644 src/crypto/mod.rs create mode 100644 src/crypto/random.rs create mode 100644 src/error.rs create mode 100644 src/main.rs create mode 100644 src/protocol/constants.rs create mode 100644 src/protocol/frame.rs create mode 100644 src/protocol/mod.rs create mode 100644 src/protocol/obfuscation.rs create mode 100644 src/protocol/tls.rs create mode 100644 src/proxy/client.rs create mode 100644 src/proxy/handshake.rs create mode 100644 src/proxy/masking.rs create mode 100644 src/proxy/mod.rs create mode 100644 src/proxy/relay.rs create mode 100644 src/stats/mod.rs create mode 100644 src/stream/crypto_stream.rs create mode 100644 src/stream/frame_stream.rs create mode 100644 src/stream/mod.rs create mode 100644 src/stream/tls_stream.rs create mode 100644 src/stream/traits.rs create mode 100644 src/transport/mod.rs create mode 100644 src/transport/pool.rs create mode 100644 src/transport/proxy_protocol.rs create mode 100644 src/transport/socket.rs create mode 100644 src/util/ip.rs create mode 100644 src/util/mod.rs create mode 100644 src/util/time.rs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b645091 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,60 @@ +[package] +name = "telemt" +version = "1.0.0" +edition = "2021" +rust-version = "1.75" + +[dependencies] +# C +libc = "0.2" + +# Async runtime +tokio = { version = "1.35", features = ["full", "tracing"] } +tokio-util = { version = "0.7", features = ["codec"] } + +# Crypto +aes = "0.8" +ctr = "0.9" +cbc = "0.1" +sha2 = "0.10" +sha1 = "0.10" +md-5 = "0.10" +hmac = "0.12" +crc32fast = "1.3" + +# Network +socket2 = { version = "0.5", features = ["all"] } +rustls = "0.22" + +# Serial +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +toml = "0.8" + +# Utils +bytes = "1.5" +thiserror = "1.0" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +parking_lot = "0.12" +dashmap = "5.5" +lru = "0.12" +rand = "0.8" +chrono = { version = "0.4", features = ["serde"] } +hex = "0.4" +base64 = "0.21" +url = "2.5" +regex = "1.10" +once_cell = "1.19" + +# HTTP +reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } + +[dev-dependencies] +tokio-test = "0.4" +criterion = "0.5" +proptest = "1.4" + +[[bench]] +name = "crypto_bench" +harness = false \ No newline at end of file diff --git a/benches/crypto_bench.rs b/benches/crypto_bench.rs new file mode 100644 index 0000000..0089abe --- /dev/null +++ b/benches/crypto_bench.rs @@ -0,0 +1,12 @@ +// Cryptobench +use criterion::{black_box, criterion_group, Criterion}; + +fn bench_aes_ctr(c: &mut Criterion) { + c.bench_function("aes_ctr_encrypt_64kb", |b| { + let data = vec![0u8; 65536]; + b.iter(|| { + let mut enc = AesCtr::new(&[0u8; 32], 0); + black_box(enc.encrypt(&data)) + }) + }); +} \ No newline at end of file diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..b8a62af --- /dev/null +++ b/config.toml @@ -0,0 +1,13 @@ +port = 443 + +[users] +user1 = "00000000000000000000000000000000" + +[modes] +classic = true +secure = true +tls = true + +tls_domain = "www.github.com" +fast_mode = true +prefer_ipv6 = false \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..3bc983e --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,227 @@ +//! Configuration + +use std::collections::HashMap; +use std::net::IpAddr; +use std::path::Path; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use crate::error::{ProxyError, Result}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProxyModes { + #[serde(default)] + pub classic: bool, + #[serde(default)] + pub secure: bool, + #[serde(default = "default_true")] + pub tls: bool, +} + +fn default_true() -> bool { true } + +impl Default for ProxyModes { + fn default() -> Self { + Self { classic: true, secure: true, tls: true } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProxyConfig { + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default)] + pub users: HashMap, + + #[serde(default)] + pub ad_tag: Option, + + #[serde(default)] + pub modes: ProxyModes, + + #[serde(default = "default_tls_domain")] + pub tls_domain: String, + + #[serde(default = "default_true")] + pub mask: bool, + + #[serde(default)] + pub mask_host: Option, + + #[serde(default = "default_mask_port")] + pub mask_port: u16, + + #[serde(default)] + pub prefer_ipv6: bool, + + #[serde(default = "default_true")] + pub fast_mode: bool, + + #[serde(default)] + pub use_middle_proxy: bool, + + #[serde(default)] + pub user_max_tcp_conns: HashMap, + + #[serde(default)] + pub user_expirations: HashMap>, + + #[serde(default)] + pub user_data_quota: HashMap, + + #[serde(default = "default_replay_check_len")] + pub replay_check_len: usize, + + #[serde(default)] + pub ignore_time_skew: bool, + + #[serde(default = "default_handshake_timeout")] + pub client_handshake_timeout: u64, + + #[serde(default = "default_connect_timeout")] + pub tg_connect_timeout: u64, + + #[serde(default = "default_keepalive")] + pub client_keepalive: u64, + + #[serde(default = "default_ack_timeout")] + pub client_ack_timeout: u64, + + #[serde(default = "default_listen_addr")] + pub listen_addr_ipv4: String, + + #[serde(default)] + pub listen_addr_ipv6: Option, + + #[serde(default)] + pub listen_unix_sock: Option, + + #[serde(default)] + pub metrics_port: Option, + + #[serde(default = "default_metrics_whitelist")] + pub metrics_whitelist: Vec, + + #[serde(default = "default_fake_cert_len")] + pub fake_cert_len: usize, +} + +fn default_port() -> u16 { 443 } +fn default_tls_domain() -> String { "www.google.com".to_string() } +fn default_mask_port() -> u16 { 443 } +fn default_replay_check_len() -> usize { 65536 } +fn default_handshake_timeout() -> u64 { 10 } +fn default_connect_timeout() -> u64 { 10 } +fn default_keepalive() -> u64 { 600 } +fn default_ack_timeout() -> u64 { 300 } +fn default_listen_addr() -> String { "0.0.0.0".to_string() } +fn default_fake_cert_len() -> usize { 2048 } + +fn default_metrics_whitelist() -> Vec { + vec![ + "127.0.0.1".parse().unwrap(), + "::1".parse().unwrap(), + ] +} + +impl Default for ProxyConfig { + fn default() -> Self { + let mut users = HashMap::new(); + users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); + + Self { + port: default_port(), + users, + ad_tag: None, + modes: ProxyModes::default(), + tls_domain: default_tls_domain(), + mask: true, + mask_host: None, + mask_port: default_mask_port(), + prefer_ipv6: false, + fast_mode: true, + use_middle_proxy: false, + user_max_tcp_conns: HashMap::new(), + user_expirations: HashMap::new(), + user_data_quota: HashMap::new(), + replay_check_len: default_replay_check_len(), + ignore_time_skew: false, + client_handshake_timeout: default_handshake_timeout(), + tg_connect_timeout: default_connect_timeout(), + client_keepalive: default_keepalive(), + client_ack_timeout: default_ack_timeout(), + listen_addr_ipv4: default_listen_addr(), + listen_addr_ipv6: Some("::".to_string()), + listen_unix_sock: None, + metrics_port: None, + metrics_whitelist: default_metrics_whitelist(), + fake_cert_len: default_fake_cert_len(), + } + } +} + +impl ProxyConfig { + pub fn load>(path: P) -> Result { + let content = std::fs::read_to_string(path) + .map_err(|e| ProxyError::Config(e.to_string()))?; + + let mut config: ProxyConfig = toml::from_str(&content) + .map_err(|e| ProxyError::Config(e.to_string()))?; + + // Validate secrets + for (user, secret) in &config.users { + if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { + return Err(ProxyError::InvalidSecret { + user: user.clone(), + reason: "Must be 32 hex characters".to_string(), + }); + } + } + + // Default mask_host + if config.mask_host.is_none() { + config.mask_host = Some(config.tls_domain.clone()); + } + + // Random fake_cert_len + use rand::Rng; + config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); + + Ok(config) + } + + pub fn validate(&self) -> Result<()> { + if self.users.is_empty() { + return Err(ProxyError::Config("No users configured".to_string())); + } + + if !self.modes.classic && !self.modes.secure && !self.modes.tls { + return Err(ProxyError::Config("No modes enabled".to_string())); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = ProxyConfig::default(); + assert_eq!(config.port, 443); + assert!(config.modes.tls); + assert_eq!(config.client_keepalive, 600); + assert_eq!(config.client_ack_timeout, 300); + } + + #[test] + fn test_config_validate() { + let mut config = ProxyConfig::default(); + assert!(config.validate().is_ok()); + + config.users.clear(); + assert!(config.validate().is_err()); + } +} \ No newline at end of file diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs new file mode 100644 index 0000000..b5651b1 --- /dev/null +++ b/src/crypto/aes.rs @@ -0,0 +1,351 @@ +//! AES + +use aes::Aes256; +use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; +use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor}; +use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding}; +use crate::error::{ProxyError, Result}; + +type Aes256Ctr = Ctr128BE; +type Aes256CbcEnc = CbcEncryptor; +type Aes256CbcDec = CbcDecryptor; + +/// AES-256-CTR encryptor/decryptor +pub struct AesCtr { + cipher: Aes256Ctr, +} + +impl AesCtr { + pub fn new(key: &[u8; 32], iv: u128) -> Self { + let iv_bytes = iv.to_be_bytes(); + Self { + cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()), + } + } + + pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result { + if key.len() != 32 { + return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + } + if iv.len() != 16 { + return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + } + + let key: [u8; 32] = key.try_into().unwrap(); + let iv = u128::from_be_bytes(iv.try_into().unwrap()); + Ok(Self::new(&key, iv)) + } + + /// Encrypt/decrypt data in-place (CTR mode is symmetric) + pub fn apply(&mut self, data: &mut [u8]) { + self.cipher.apply_keystream(data); + } + + /// Encrypt data, returning new buffer + pub fn encrypt(&mut self, data: &[u8]) -> Vec { + let mut output = data.to_vec(); + self.apply(&mut output); + output + } + + /// Decrypt data (for CTR, identical to encrypt) + pub fn decrypt(&mut self, data: &[u8]) -> Vec { + self.encrypt(data) + } +} + +/// AES-256-CBC Ciphermagic +pub struct AesCbc { + key: [u8; 32], + iv: [u8; 16], +} + +impl AesCbc { + pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { + Self { key, iv } + } + + pub fn from_slices(key: &[u8], iv: &[u8]) -> Result { + if key.len() != 32 { + return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + } + if iv.len() != 16 { + return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + } + + Ok(Self { + key: key.try_into().unwrap(), + iv: iv.try_into().unwrap(), + }) + } + + /// Encrypt data using CBC mode + pub fn encrypt(&self, data: &[u8]) -> Result> { + if data.len() % 16 != 0 { + return Err(ProxyError::Crypto( + format!("CBC data must be aligned to 16 bytes, got {}", data.len()) + )); + } + + if data.is_empty() { + return Ok(Vec::new()); + } + + let mut buffer = data.to_vec(); + + let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into()); + + for chunk in buffer.chunks_mut(16) { + encryptor.encrypt_block_mut(chunk.into()); + } + + Ok(buffer) + } + + /// Decrypt data using CBC mode + pub fn decrypt(&self, data: &[u8]) -> Result> { + if data.len() % 16 != 0 { + return Err(ProxyError::Crypto( + format!("CBC data must be aligned to 16 bytes, got {}", data.len()) + )); + } + + if data.is_empty() { + return Ok(Vec::new()); + } + + let mut buffer = data.to_vec(); + + let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); + + for chunk in buffer.chunks_mut(16) { + decryptor.decrypt_block_mut(chunk.into()); + } + + Ok(buffer) + } + + /// Encrypt data in-place + pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { + if data.len() % 16 != 0 { + return Err(ProxyError::Crypto( + format!("CBC data must be aligned to 16 bytes, got {}", data.len()) + )); + } + + if data.is_empty() { + return Ok(()); + } + + let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into()); + + for chunk in data.chunks_mut(16) { + encryptor.encrypt_block_mut(chunk.into()); + } + + Ok(()) + } + + /// Decrypt data in-place + pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { + if data.len() % 16 != 0 { + return Err(ProxyError::Crypto( + format!("CBC data must be aligned to 16 bytes, got {}", data.len()) + )); + } + + if data.is_empty() { + return Ok(()); + } + + let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); + + for chunk in data.chunks_mut(16) { + decryptor.decrypt_block_mut(chunk.into()); + } + + Ok(()) + } +} + +/// Trait for unified encryption interface +pub trait Encryptor: Send + Sync { + fn encrypt(&mut self, data: &[u8]) -> Vec; +} + +/// Trait for unified decryption interface +pub trait Decryptor: Send + Sync { + fn decrypt(&mut self, data: &[u8]) -> Vec; +} + +impl Encryptor for AesCtr { + fn encrypt(&mut self, data: &[u8]) -> Vec { + AesCtr::encrypt(self, data) + } +} + +impl Decryptor for AesCtr { + fn decrypt(&mut self, data: &[u8]) -> Vec { + AesCtr::decrypt(self, data) + } +} + +/// No-op encryptor for fast mode +pub struct PassthroughEncryptor; + +impl Encryptor for PassthroughEncryptor { + fn encrypt(&mut self, data: &[u8]) -> Vec { + data.to_vec() + } +} + +impl Decryptor for PassthroughEncryptor { + fn decrypt(&mut self, data: &[u8]) -> Vec { + data.to_vec() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_aes_ctr_roundtrip() { + let key = [0u8; 32]; + let iv = 12345u128; + + let original = b"Hello, MTProto!"; + + let mut enc = AesCtr::new(&key, iv); + let encrypted = enc.encrypt(original); + + let mut dec = AesCtr::new(&key, iv); + let decrypted = dec.decrypt(&encrypted); + + assert_eq!(original.as_slice(), decrypted.as_slice()); + } + + #[test] + fn test_aes_cbc_roundtrip() { + let key = [0u8; 32]; + let iv = [0u8; 16]; + + // Must be aligned to 16 bytes + let original = [0u8; 32]; + + let cipher = AesCbc::new(key, iv); + let encrypted = cipher.encrypt(&original).unwrap(); + let decrypted = cipher.decrypt(&encrypted).unwrap(); + + assert_eq!(original.as_slice(), decrypted.as_slice()); + } + + #[test] + fn test_aes_cbc_chaining_works() { + let key = [0x42u8; 32]; + let iv = [0x00u8; 16]; + + let plaintext = [0xAA_u8; 32]; + + let cipher = AesCbc::new(key, iv); + let ciphertext = cipher.encrypt(&plaintext).unwrap(); + + // CBC Corrections + let block1 = &ciphertext[0..16]; + let block2 = &ciphertext[16..32]; + + assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext"); + } + + #[test] + fn test_aes_cbc_known_vector() { + let key = [0u8; 32]; + let iv = [0u8; 16]; + + // 3 Datablocks + let plaintext = [ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, + // Block 2 + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, + // Block 3 - different + 0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88, + 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00, + ]; + + let cipher = AesCbc::new(key, iv); + let ciphertext = cipher.encrypt(&plaintext).unwrap(); + + // Decrypt + Verify + let decrypted = cipher.decrypt(&ciphertext).unwrap(); + assert_eq!(plaintext.as_slice(), decrypted.as_slice()); + + // Verify Ciphertexts Block 1 != Block 2 + assert_ne!(&ciphertext[0..16], &ciphertext[16..32]); + } + + #[test] + fn test_aes_cbc_in_place() { + let key = [0x12u8; 32]; + let iv = [0x34u8; 16]; + + let original = [0x56u8; 48]; // 3 blocks + let mut buffer = original.clone(); + + let cipher = AesCbc::new(key, iv); + + cipher.encrypt_in_place(&mut buffer).unwrap(); + assert_ne!(&buffer[..], &original[..]); + + cipher.decrypt_in_place(&mut buffer).unwrap(); + assert_eq!(&buffer[..], &original[..]); + } + + #[test] + fn test_aes_cbc_empty_data() { + let cipher = AesCbc::new([0u8; 32], [0u8; 16]); + + let encrypted = cipher.encrypt(&[]).unwrap(); + assert!(encrypted.is_empty()); + + let decrypted = cipher.decrypt(&[]).unwrap(); + assert!(decrypted.is_empty()); + } + + #[test] + fn test_aes_cbc_unaligned_error() { + let cipher = AesCbc::new([0u8; 32], [0u8; 16]); + + // 15 bytes + let result = cipher.encrypt(&[0u8; 15]); + assert!(result.is_err()); + + // 17 bytes + let result = cipher.encrypt(&[0u8; 17]); + assert!(result.is_err()); + } + + #[test] + fn test_aes_cbc_avalanche_effect() { + // Cipherplane + + let key = [0xAB; 32]; + let iv = [0xCD; 16]; + + let mut plaintext1 = [0u8; 32]; + let mut plaintext2 = [0u8; 32]; + plaintext2[0] = 0x01; // Один бит отличается + + let cipher = AesCbc::new(key, iv); + + let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); + let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); + + // First Blocks Diff + assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); + + // Second Blocks Diff + assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); + } +} \ No newline at end of file diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs new file mode 100644 index 0000000..0472018 --- /dev/null +++ b/src/crypto/hash.rs @@ -0,0 +1,90 @@ +use hmac::{Hmac, Mac}; +use sha2::Sha256; +use md5::Md5; +use sha1::Sha1; +use sha2::Digest; + +type HmacSha256 = Hmac; + +/// SHA-256 +pub fn sha256(data: &[u8]) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +/// SHA-256 HMAC +pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { + let mut mac = HmacSha256::new_from_slice(key) + .expect("HMAC accepts any key length"); + mac.update(data); + mac.finalize().into_bytes().into() +} + +/// SHA-1 +pub fn sha1(data: &[u8]) -> [u8; 20] { + let mut hasher = Sha1::new(); + hasher.update(data); + hasher.finalize().into() +} + +/// MD5 +pub fn md5(data: &[u8]) -> [u8; 16] { + let mut hasher = Md5::new(); + hasher.update(data); + hasher.finalize().into() +} + +/// CRC32 +pub fn crc32(data: &[u8]) -> u32 { + crc32fast::hash(data) +} + +/// Middle Proxy Keygen +pub fn derive_middleproxy_keys( + nonce_srv: &[u8; 16], + nonce_clt: &[u8; 16], + clt_ts: &[u8; 4], + srv_ip: Option<&[u8]>, + clt_port: &[u8; 2], + purpose: &[u8], + clt_ip: Option<&[u8]>, + srv_port: &[u8; 2], + secret: &[u8], + clt_ipv6: Option<&[u8; 16]>, + srv_ipv6: Option<&[u8; 16]>, +) -> ([u8; 32], [u8; 16]) { + const EMPTY_IP: [u8; 4] = [0, 0, 0, 0]; + + let srv_ip = srv_ip.unwrap_or(&EMPTY_IP); + let clt_ip = clt_ip.unwrap_or(&EMPTY_IP); + + let mut s = Vec::with_capacity(256); + s.extend_from_slice(nonce_srv); + s.extend_from_slice(nonce_clt); + s.extend_from_slice(clt_ts); + s.extend_from_slice(srv_ip); + s.extend_from_slice(clt_port); + s.extend_from_slice(purpose); + s.extend_from_slice(clt_ip); + s.extend_from_slice(srv_port); + s.extend_from_slice(secret); + s.extend_from_slice(nonce_srv); + + if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) { + s.extend_from_slice(clt_v6); + s.extend_from_slice(srv_v6); + } + + s.extend_from_slice(nonce_clt); + + let md5_1 = md5(&s[1..]); + let sha1_sum = sha1(&s); + let md5_2 = md5(&s[2..]); + + let mut key = [0u8; 32]; + key[..12].copy_from_slice(&md5_1[..12]); + key[12..].copy_from_slice(&sha1_sum); + + (key, md5_2) +} \ No newline at end of file diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs new file mode 100644 index 0000000..6339927 --- /dev/null +++ b/src/crypto/mod.rs @@ -0,0 +1,9 @@ +//! Crypto + +pub mod aes; +pub mod hash; +pub mod random; + +pub use aes::{AesCtr, AesCbc}; +pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; +pub use random::{SecureRandom, SECURE_RANDOM}; \ No newline at end of file diff --git a/src/crypto/random.rs b/src/crypto/random.rs new file mode 100644 index 0000000..c179f25 --- /dev/null +++ b/src/crypto/random.rs @@ -0,0 +1,212 @@ +//! Pseudorandom + +use rand::{Rng, RngCore, SeedableRng}; +use rand::rngs::StdRng; +use parking_lot::Mutex; +use crate::crypto::AesCtr; +use once_cell::sync::Lazy; + +/// Global secure random instance +pub static SECURE_RANDOM: Lazy = Lazy::new(SecureRandom::new); + +/// Cryptographically secure PRNG with AES-CTR +pub struct SecureRandom { + inner: Mutex, +} + +struct SecureRandomInner { + rng: StdRng, + cipher: AesCtr, + buffer: Vec, +} + +impl SecureRandom { + pub fn new() -> Self { + let mut rng = StdRng::from_entropy(); + + let mut key = [0u8; 32]; + rng.fill_bytes(&mut key); + let iv: u128 = rng.gen(); + + Self { + inner: Mutex::new(SecureRandomInner { + rng, + cipher: AesCtr::new(&key, iv), + buffer: Vec::with_capacity(1024), + }), + } + } + + /// Generate random bytes + pub fn bytes(&self, len: usize) -> Vec { + let mut inner = self.inner.lock(); + const CHUNK_SIZE: usize = 512; + + while inner.buffer.len() < len { + let mut chunk = vec![0u8; CHUNK_SIZE]; + inner.rng.fill_bytes(&mut chunk); + inner.cipher.apply(&mut chunk); + inner.buffer.extend_from_slice(&chunk); + } + + inner.buffer.drain(..len).collect() + } + + /// Generate random number in range [0, max) + pub fn range(&self, max: usize) -> usize { + if max == 0 { + return 0; + } + let mut inner = self.inner.lock(); + inner.rng.gen_range(0..max) + } + + /// Generate random bits + pub fn bits(&self, k: usize) -> u64 { + if k == 0 { + return 0; + } + + let bytes_needed = (k + 7) / 8; + let bytes = self.bytes(bytes_needed.min(8)); + + let mut result = 0u64; + for (i, &b) in bytes.iter().enumerate() { + if i >= 8 { + break; + } + result |= (b as u64) << (i * 8); + } + + // Mask extra bits + if k < 64 { + result &= (1u64 << k) - 1; + } + + result + } + + /// Choose random element from slice + pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> { + if slice.is_empty() { + None + } else { + Some(&slice[self.range(slice.len())]) + } + } + + /// Shuffle slice in place + pub fn shuffle(&self, slice: &mut [T]) { + let mut inner = self.inner.lock(); + for i in (1..slice.len()).rev() { + let j = inner.rng.gen_range(0..=i); + slice.swap(i, j); + } + } + + /// Generate random u32 + pub fn u32(&self) -> u32 { + let mut inner = self.inner.lock(); + inner.rng.gen() + } + + /// Generate random u64 + pub fn u64(&self) -> u64 { + let mut inner = self.inner.lock(); + inner.rng.gen() + } +} + +impl Default for SecureRandom { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashSet; + + #[test] + fn test_bytes_uniqueness() { + let rng = SecureRandom::new(); + let a = rng.bytes(32); + let b = rng.bytes(32); + assert_ne!(a, b); + } + + #[test] + fn test_bytes_length() { + let rng = SecureRandom::new(); + assert_eq!(rng.bytes(0).len(), 0); + assert_eq!(rng.bytes(1).len(), 1); + assert_eq!(rng.bytes(100).len(), 100); + assert_eq!(rng.bytes(1000).len(), 1000); + } + + #[test] + fn test_range() { + let rng = SecureRandom::new(); + + for _ in 0..1000 { + let n = rng.range(10); + assert!(n < 10); + } + + assert_eq!(rng.range(1), 0); + assert_eq!(rng.range(0), 0); + } + + #[test] + fn test_bits() { + let rng = SecureRandom::new(); + + // Single bit should be 0 or 1 + for _ in 0..100 { + assert!(rng.bits(1) <= 1); + } + + // 8 bits should be 0-255 + for _ in 0..100 { + assert!(rng.bits(8) <= 255); + } + } + + #[test] + fn test_choose() { + let rng = SecureRandom::new(); + let items = vec![1, 2, 3, 4, 5]; + + let mut seen = HashSet::new(); + for _ in 0..1000 { + if let Some(&item) = rng.choose(&items) { + seen.insert(item); + } + } + + // Should have seen all items + assert_eq!(seen.len(), 5); + + // Empty slice should return None + let empty: Vec = vec![]; + assert!(rng.choose(&empty).is_none()); + } + + #[test] + fn test_shuffle() { + let rng = SecureRandom::new(); + let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + let mut shuffled = original.clone(); + rng.shuffle(&mut shuffled); + + // Should contain same elements + let mut sorted = shuffled.clone(); + sorted.sort(); + assert_eq!(sorted, original); + + // Should be different order (with very high probability) + assert_ne!(shuffled, original); + } +} \ No newline at end of file diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..c14e148 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,176 @@ +//! Error Types + +use std::net::SocketAddr; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ProxyError { + // ============= Crypto Errors ============= + + #[error("Crypto error: {0}")] + Crypto(String), + + #[error("Invalid key length: expected {expected}, got {got}")] + InvalidKeyLength { expected: usize, got: usize }, + + // ============= Protocol Errors ============= + + #[error("Invalid handshake: {0}")] + InvalidHandshake(String), + + #[error("Invalid protocol tag: {0:02x?}")] + InvalidProtoTag([u8; 4]), + + #[error("Invalid TLS record: type={record_type}, version={version:02x?}")] + InvalidTlsRecord { record_type: u8, version: [u8; 2] }, + + #[error("Replay attack detected from {addr}")] + ReplayAttack { addr: SocketAddr }, + + #[error("Time skew detected: client={client_time}, server={server_time}")] + TimeSkew { client_time: u32, server_time: u32 }, + + #[error("Invalid message length: {len} (min={min}, max={max})")] + InvalidMessageLength { len: usize, min: usize, max: usize }, + + #[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")] + ChecksumMismatch { expected: u32, got: u32 }, + + #[error("Sequence number mismatch: expected={expected}, got={got}")] + SeqNoMismatch { expected: i32, got: i32 }, + + // ============= Network Errors ============= + + #[error("Connection timeout to {addr}")] + ConnectionTimeout { addr: String }, + + #[error("Connection refused by {addr}")] + ConnectionRefused { addr: String }, + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + // ============= Proxy Protocol Errors ============= + + #[error("Invalid proxy protocol header")] + InvalidProxyProtocol, + + // ============= Config Errors ============= + + #[error("Config error: {0}")] + Config(String), + + #[error("Invalid secret for user {user}: {reason}")] + InvalidSecret { user: String, reason: String }, + + // ============= User Errors ============= + + #[error("User {user} expired")] + UserExpired { user: String }, + + #[error("User {user} exceeded connection limit")] + ConnectionLimitExceeded { user: String }, + + #[error("User {user} exceeded data quota")] + DataQuotaExceeded { user: String }, + + #[error("Unknown user")] + UnknownUser, + + // ============= General Errors ============= + + #[error("Internal error: {0}")] + Internal(String), +} + +/// Convenient Result type alias +pub type Result = std::result::Result; + +/// Result with optional bad client handling +#[derive(Debug)] +pub enum HandshakeResult { + /// Handshake succeeded + Success(T), + /// Client failed validation, needs masking + BadClient, + /// Error occurred + Error(ProxyError), +} + +impl HandshakeResult { + /// Check if successful + pub fn is_success(&self) -> bool { + matches!(self, HandshakeResult::Success(_)) + } + + /// Check if bad client + pub fn is_bad_client(&self) -> bool { + matches!(self, HandshakeResult::BadClient) + } + + /// Convert to Result, treating BadClient as error + pub fn into_result(self) -> Result { + match self { + HandshakeResult::Success(v) => Ok(v), + HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())), + HandshakeResult::Error(e) => Err(e), + } + } + + /// Map the success value + pub fn map U>(self, f: F) -> HandshakeResult { + match self { + HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), + HandshakeResult::BadClient => HandshakeResult::BadClient, + HandshakeResult::Error(e) => HandshakeResult::Error(e), + } + } +} + +impl From for HandshakeResult { + fn from(err: ProxyError) -> Self { + HandshakeResult::Error(err) + } +} + +impl From for HandshakeResult { + fn from(err: std::io::Error) -> Self { + HandshakeResult::Error(ProxyError::Io(err)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_handshake_result() { + let success: HandshakeResult = HandshakeResult::Success(42); + assert!(success.is_success()); + assert!(!success.is_bad_client()); + + let bad: HandshakeResult = HandshakeResult::BadClient; + assert!(!bad.is_success()); + assert!(bad.is_bad_client()); + } + + #[test] + fn test_handshake_result_map() { + let success: HandshakeResult = HandshakeResult::Success(42); + let mapped = success.map(|x| x * 2); + + match mapped { + HandshakeResult::Success(v) => assert_eq!(v, 84), + _ => panic!("Expected success"), + } + } + + #[test] + fn test_error_display() { + let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() }; + assert!(err.to_string().contains("1.2.3.4:443")); + + let err = ProxyError::InvalidProxyProtocol; + assert!(err.to_string().contains("proxy protocol")); + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..29e4bc1 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,158 @@ +//! Telemt - MTProxy on Rust + +use std::sync::Arc; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tokio::signal; +use tracing::{info, error, Level}; +use tracing_subscriber::{FmtSubscriber, EnvFilter}; + +mod error; +mod crypto; +mod protocol; +mod stream; +mod transport; +mod proxy; +mod config; +mod stats; +mod util; + +use config::ProxyConfig; +use stats::{Stats, ReplayChecker}; +use transport::ConnectionPool; +use proxy::ClientHandler; + +#[tokio::main] +async fn main() -> std::result::Result<(), Box> { + // Initialize logging with env filter + // Use RUST_LOG=debug or RUST_LOG=trace for more details + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("info")); + + let subscriber = FmtSubscriber::builder() + .with_env_filter(filter) + .with_target(true) + .with_thread_ids(false) + .with_file(false) + .with_line_number(false) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + + // Load configuration + let config_path = std::env::args() + .nth(1) + .unwrap_or_else(|| "config.toml".to_string()); + + info!("Loading configuration from {}", config_path); + + let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| { + error!("Failed to load config: {}", e); + info!("Using default configuration"); + ProxyConfig::default() + }); + + if let Err(e) = config.validate() { + error!("Invalid configuration: {}", e); + std::process::exit(1); + } + + let config = Arc::new(config); + + info!("Starting MTProto Proxy on port {}", config.port); + info!("Fast mode: {}", config.fast_mode); + info!("Modes: classic={}, secure={}, tls={}", + config.modes.classic, config.modes.secure, config.modes.tls); + + // Initialize components + let stats = Arc::new(Stats::new()); + let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); + let pool = Arc::new(ConnectionPool::new()); + + // Create handler + let handler = Arc::new(ClientHandler::new( + Arc::clone(&config), + Arc::clone(&stats), + Arc::clone(&replay_checker), + Arc::clone(&pool), + )); + + // Start listener + let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port) + .parse()?; + + let listener = TcpListener::bind(addr).await?; + info!("Listening on {}", addr); + + // Print proxy links + print_proxy_links(&config); + + info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging"); + + // Main accept loop + let accept_loop = async { + loop { + match listener.accept().await { + Ok((stream, peer)) => { + let handler = Arc::clone(&handler); + tokio::spawn(async move { + handler.handle(stream, peer).await; + }); + } + Err(e) => { + error!("Accept error: {}", e); + } + } + } + }; + + // Graceful shutdown + tokio::select! { + _ = accept_loop => {} + _ = signal::ctrl_c() => { + info!("Shutting down..."); + } + } + + // Cleanup + pool.close_all().await; + + info!("Goodbye!"); + Ok(()) +} + +fn print_proxy_links(config: &ProxyConfig) { + println!("\n=== Proxy Links ===\n"); + + for (user, secret) in &config.users { + if config.modes.tls { + let tls_secret = format!( + "ee{}{}", + secret, + hex::encode(config.tls_domain.as_bytes()) + ); + println!( + "{} (TLS): tg://proxy?server=IP&port={}&secret={}", + user, config.port, tls_secret + ); + } + + if config.modes.secure { + println!( + "{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}", + user, config.port, secret + ); + } + + if config.modes.classic { + println!( + "{} (Classic): tg://proxy?server=IP&port={}&secret={}", + user, config.port, secret + ); + } + + println!(); + } + + println!("===================\n"); +} \ No newline at end of file diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs new file mode 100644 index 0000000..17857e4 --- /dev/null +++ b/src/protocol/constants.rs @@ -0,0 +1,261 @@ +//! Protocol constants and datacenter addresses + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use once_cell::sync::Lazy; + +// ============= Telegram Datacenters ============= + +pub const TG_DATACENTER_PORT: u16 = 443; + +pub static TG_DATACENTERS_V4: Lazy> = Lazy::new(|| { + vec![ + IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), + IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)), + IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), + IpAddr::V4(Ipv4Addr::new(149, 154, 167, 91)), + IpAddr::V4(Ipv4Addr::new(149, 154, 171, 5)), + ] +}); + +pub static TG_DATACENTERS_V6: Lazy> = Lazy::new(|| { + vec![ + IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()), + IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()), + IpAddr::V6("2001:b28:f23d:f003::a".parse().unwrap()), + IpAddr::V6("2001:67c:04e8:f004::a".parse().unwrap()), + IpAddr::V6("2001:b28:f23f:f005::a".parse().unwrap()), + ] +}); + +// ============= Middle Proxies (for advertising) ============= + +pub static TG_MIDDLE_PROXIES_V4: Lazy>> = + Lazy::new(|| { + let mut m = std::collections::HashMap::new(); + m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); + m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); + m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); + m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); + m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); + m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); + m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]); + m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]); + m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); + m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); + m + }); + +pub static TG_MIDDLE_PROXIES_V6: Lazy>> = + Lazy::new(|| { + let mut m = std::collections::HashMap::new(); + m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); + m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); + m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); + m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); + m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); + m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); + m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); + m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); + m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); + m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); + m + }); + +// ============= Protocol Tags ============= + +/// MTProto transport protocol variants +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ProtoTag { + /// Abridged protocol - compact framing + Abridged = 0xefefefef, + /// Intermediate protocol - simple 4-byte length prefix + Intermediate = 0xeeeeeeee, + /// Secure intermediate - with random padding + Secure = 0xdddddddd, +} + +impl ProtoTag { + /// Parse protocol tag from 4 bytes + pub fn from_bytes(bytes: [u8; 4]) -> Option { + match u32::from_le_bytes(bytes) { + 0xefefefef => Some(ProtoTag::Abridged), + 0xeeeeeeee => Some(ProtoTag::Intermediate), + 0xdddddddd => Some(ProtoTag::Secure), + _ => None, + } + } + + /// Convert to 4 bytes (little-endian) + pub fn to_bytes(self) -> [u8; 4] { + (self as u32).to_le_bytes() + } + + /// Get protocol tag as bytes slice + pub fn as_bytes(&self) -> &'static [u8; 4] { + match self { + ProtoTag::Abridged => &PROTO_TAG_ABRIDGED, + ProtoTag::Intermediate => &PROTO_TAG_INTERMEDIATE, + ProtoTag::Secure => &PROTO_TAG_SECURE, + } + } +} + +/// Protocol tag bytes +pub const PROTO_TAG_ABRIDGED: [u8; 4] = [0xef, 0xef, 0xef, 0xef]; +pub const PROTO_TAG_INTERMEDIATE: [u8; 4] = [0xee, 0xee, 0xee, 0xee]; +pub const PROTO_TAG_SECURE: [u8; 4] = [0xdd, 0xdd, 0xdd, 0xdd]; + +// ============= Handshake Layout ============= + +/// Bytes to skip at the start of handshake +pub const SKIP_LEN: usize = 8; +/// Pre-key length (before hashing with secret) +pub const PREKEY_LEN: usize = 32; +/// AES key length +pub const KEY_LEN: usize = 32; +/// AES IV length +pub const IV_LEN: usize = 16; +/// Total handshake length +pub const HANDSHAKE_LEN: usize = 64; +/// Position of protocol tag in decrypted handshake +pub const PROTO_TAG_POS: usize = 56; +/// Position of datacenter index +pub const DC_IDX_POS: usize = 60; + +// ============= Message Limits ============= + +/// Minimum message length +pub const MIN_MSG_LEN: usize = 12; +/// Maximum message length (16 MB) +pub const MAX_MSG_LEN: usize = 1 << 24; +/// CBC block padding size +pub const CBC_PADDING: usize = 16; +/// Padding filler bytes +pub const PADDING_FILLER: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; + +// ============= TLS Constants ============= + +/// Minimum certificate length for detection +pub const MIN_CERT_LEN: usize = 1024; +/// TLS 1.3 version bytes +pub const TLS_VERSION: [u8; 2] = [0x03, 0x03]; +/// TLS record type: Handshake +pub const TLS_RECORD_HANDSHAKE: u8 = 0x16; +/// TLS record type: Change Cipher Spec +pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14; +/// TLS record type: Application Data +pub const TLS_RECORD_APPLICATION: u8 = 0x17; +/// TLS record type: Alert +pub const TLS_RECORD_ALERT: u8 = 0x15; +/// Maximum TLS record size +pub const MAX_TLS_RECORD_SIZE: usize = 16384; +/// Maximum TLS chunk size (with overhead) +pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 24; + +// ============= Timeouts ============= + +/// Default handshake timeout in seconds +pub const DEFAULT_HANDSHAKE_TIMEOUT_SECS: u64 = 10; +/// Default connect timeout in seconds +pub const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10; +/// Default keepalive interval in seconds +pub const DEFAULT_KEEPALIVE_SECS: u64 = 600; +/// Default ACK timeout in seconds +pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300; + +// ============= Buffer Sizes ============= + +/// Default buffer size +pub const DEFAULT_BUFFER_SIZE: usize = 65536; +/// Small buffer size for bad client handling +pub const SMALL_BUFFER_SIZE: usize = 8192; + +// ============= Statistics ============= + +/// Duration buckets for histogram metrics +pub static DURATION_BUCKETS: &[f64] = &[ + 0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0, +]; + +// ============= Reserved Nonce Patterns ============= + +/// Reserved first bytes of nonce (must avoid) +pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef]; + +/// Reserved 4-byte beginnings of nonce +pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[ + [0x48, 0x45, 0x41, 0x44], // HEAD + [0x50, 0x4F, 0x53, 0x54], // POST + [0x47, 0x45, 0x54, 0x20], // GET + [0xee, 0xee, 0xee, 0xee], // Intermediate + [0xdd, 0xdd, 0xdd, 0xdd], // Secure + [0x16, 0x03, 0x01, 0x02], // TLS +]; + +/// Reserved continuation bytes (bytes 4-7) +pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[ + [0x00, 0x00, 0x00, 0x00], +]; + +// ============= RPC Constants (for Middle Proxy) ============= + +/// RPC Proxy Request +pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; +/// RPC Proxy Answer +pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44]; +/// RPC Close Extended +pub const RPC_CLOSE_EXT: [u8; 4] = [0xa2, 0x34, 0xb6, 0x5e]; +/// RPC Simple ACK +pub const RPC_SIMPLE_ACK: [u8; 4] = [0x9b, 0x40, 0xac, 0x3b]; +/// RPC Unknown +pub const RPC_UNKNOWN: [u8; 4] = [0xdf, 0xa2, 0x30, 0x57]; +/// RPC Handshake +pub const RPC_HANDSHAKE: [u8; 4] = [0xf5, 0xee, 0x82, 0x76]; +/// RPC Nonce +pub const RPC_NONCE: [u8; 4] = [0xaa, 0x87, 0xcb, 0x7a]; + +/// RPC Flags +pub mod rpc_flags { + pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; + pub const FLAG_HAS_AD_TAG: u32 = 0x8; + pub const FLAG_MAGIC: u32 = 0x1000; + pub const FLAG_EXTMODE2: u32 = 0x20000; + pub const FLAG_PAD: u32 = 0x8000000; + pub const FLAG_INTERMEDIATE: u32 = 0x20000000; + pub const FLAG_ABRIDGED: u32 = 0x40000000; + pub const FLAG_QUICKACK: u32 = 0x80000000; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_proto_tag_roundtrip() { + for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { + let bytes = tag.to_bytes(); + let parsed = ProtoTag::from_bytes(bytes).unwrap(); + assert_eq!(tag, parsed); + } + } + + #[test] + fn test_proto_tag_values() { + assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED); + assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE); + assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE); + } + + #[test] + fn test_invalid_proto_tag() { + assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none()); + assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none()); + } + + #[test] + fn test_datacenters_count() { + assert_eq!(TG_DATACENTERS_V4.len(), 5); + assert_eq!(TG_DATACENTERS_V6.len(), 5); + } +} \ No newline at end of file diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs new file mode 100644 index 0000000..f4517b4 --- /dev/null +++ b/src/protocol/frame.rs @@ -0,0 +1,120 @@ +//! MTProto frame types and metadata + +use std::collections::HashMap; + +/// Extra metadata associated with a frame +#[derive(Debug, Clone, Default)] +pub struct FrameExtra { + /// Quick ACK flag - request immediate acknowledgment + pub quickack: bool, + /// Simple ACK - this is an acknowledgment message + pub simple_ack: bool, + /// Skip sending - internal flag to skip forwarding + pub skip_send: bool, + /// Custom key-value metadata + pub custom: HashMap, +} + +impl FrameExtra { + /// Create new empty frame extra + pub fn new() -> Self { + Self::default() + } + + /// Create with quickack flag set + pub fn with_quickack() -> Self { + Self { + quickack: true, + ..Default::default() + } + } + + /// Create with simple_ack flag set + pub fn with_simple_ack() -> Self { + Self { + simple_ack: true, + ..Default::default() + } + } + + /// Check if any flags are set + pub fn has_flags(&self) -> bool { + self.quickack || self.simple_ack || self.skip_send + } +} + +/// Result of reading a frame +#[derive(Debug)] +pub enum FrameReadResult { + /// Successfully read a frame with data and metadata + Data(Vec, FrameExtra), + /// Connection closed normally + Closed, + /// Need more data (for non-blocking reads) + WouldBlock, +} + +/// Frame encoding/decoding mode +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FrameMode { + /// Abridged - 1 or 4 byte length prefix + Abridged, + /// Intermediate - 4 byte length prefix + Intermediate, + /// Secure Intermediate - 4 byte length with padding + SecureIntermediate, + /// Full MTProto - with seq_no and CRC32 + Full, +} + +impl FrameMode { + /// Get maximum overhead for this frame mode + pub fn max_overhead(&self) -> usize { + match self { + FrameMode::Abridged => 4, + FrameMode::Intermediate => 4, + FrameMode::SecureIntermediate => 4 + 3, // length + padding + FrameMode::Full => 12 + 16, // header + max CBC padding + } + } +} + +/// Validate message length for MTProto +pub fn validate_message_length(len: usize) -> bool { + use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER}; + + len >= MIN_MSG_LEN && len <= MAX_MSG_LEN && len % PADDING_FILLER.len() == 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_frame_extra_default() { + let extra = FrameExtra::default(); + assert!(!extra.quickack); + assert!(!extra.simple_ack); + assert!(!extra.skip_send); + assert!(!extra.has_flags()); + } + + #[test] + fn test_frame_extra_flags() { + let extra = FrameExtra::with_quickack(); + assert!(extra.quickack); + assert!(extra.has_flags()); + + let extra = FrameExtra::with_simple_ack(); + assert!(extra.simple_ack); + assert!(extra.has_flags()); + } + + #[test] + fn test_validate_message_length() { + assert!(validate_message_length(12)); // MIN_MSG_LEN + assert!(validate_message_length(16)); + assert!(!validate_message_length(8)); // Too small + assert!(!validate_message_length(13)); // Not aligned to 4 + } +} \ No newline at end of file diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..4081f1c --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,11 @@ +//! MTProto Defs + Cons + +pub mod constants; +pub mod frame; +pub mod obfuscation; +pub mod tls; + +pub use constants::*; +pub use frame::*; +pub use obfuscation::*; +pub use tls::*; \ No newline at end of file diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs new file mode 100644 index 0000000..4e09942 --- /dev/null +++ b/src/protocol/obfuscation.rs @@ -0,0 +1,217 @@ +//! MTProto Obfuscation + +use crate::crypto::{sha256, AesCtr}; +use crate::error::Result; +use super::constants::*; + +/// Obfuscation parameters from handshake +#[derive(Debug, Clone)] +pub struct ObfuscationParams { + /// Key for decrypting client -> proxy traffic + pub decrypt_key: [u8; 32], + /// IV for decrypting client -> proxy traffic + pub decrypt_iv: u128, + /// Key for encrypting proxy -> client traffic + pub encrypt_key: [u8; 32], + /// IV for encrypting proxy -> client traffic + pub encrypt_iv: u128, + /// Protocol tag (abridged/intermediate/secure) + pub proto_tag: ProtoTag, + /// Datacenter index + pub dc_idx: i16, +} + +impl ObfuscationParams { + /// Parse obfuscation parameters from handshake bytes + /// Returns None if handshake doesn't match any user secret + pub fn from_handshake( + handshake: &[u8; HANDSHAKE_LEN], + secrets: &[(String, Vec)], // (username, secret_bytes) + ) -> Option<(Self, String)> { + // Extract prekey and IV for decryption + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; + let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; + + // Reversed for encryption direction + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); + let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; + let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; + + for (username, secret) in secrets { + // Derive decryption key + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(secret); + let decrypt_key = sha256(&dec_key_input); + + let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + + // Create decryptor and decrypt handshake + let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); + let decrypted = decryptor.decrypt(handshake); + + // Check protocol tag + let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] + .try_into() + .unwrap(); + + let proto_tag = match ProtoTag::from_bytes(tag_bytes) { + Some(tag) => tag, + None => continue, // Try next secret + }; + + // Extract DC index + let dc_idx = i16::from_le_bytes( + decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() + ); + + // Derive encryption key + let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + enc_key_input.extend_from_slice(enc_prekey); + enc_key_input.extend_from_slice(secret); + let encrypt_key = sha256(&enc_key_input); + let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + + return Some(( + ObfuscationParams { + decrypt_key, + decrypt_iv, + encrypt_key, + encrypt_iv, + proto_tag, + dc_idx, + }, + username.clone(), + )); + } + + None + } + + /// Create AES-CTR decryptor for client -> proxy direction + pub fn create_decryptor(&self) -> AesCtr { + AesCtr::new(&self.decrypt_key, self.decrypt_iv) + } + + /// Create AES-CTR encryptor for proxy -> client direction + pub fn create_encryptor(&self) -> AesCtr { + AesCtr::new(&self.encrypt_key, self.encrypt_iv) + } + + /// Get the combined encrypt key and IV for fast mode + pub fn enc_key_iv(&self) -> Vec { + let mut result = Vec::with_capacity(KEY_LEN + IV_LEN); + result.extend_from_slice(&self.encrypt_key); + result.extend_from_slice(&self.encrypt_iv.to_be_bytes()); + result + } +} + +/// Generate a valid random nonce for Telegram handshake +pub fn generate_nonce Vec>(mut random_bytes: R) -> [u8; HANDSHAKE_LEN] { + loop { + let nonce_vec = random_bytes(HANDSHAKE_LEN); + let mut nonce = [0u8; HANDSHAKE_LEN]; + nonce.copy_from_slice(&nonce_vec); + + if is_valid_nonce(&nonce) { + return nonce; + } + } +} + +/// Check if nonce is valid (not matching reserved patterns) +pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { + // Check first byte + if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { + return false; + } + + // Check first 4 bytes + let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); + if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { + return false; + } + + // Check bytes 4-7 + let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); + if RESERVED_NONCE_CONTINUES.contains(&continue_four) { + return false; + } + + true +} + +/// Prepare nonce for sending to Telegram +pub fn prepare_tg_nonce( + nonce: &mut [u8; HANDSHAKE_LEN], + proto_tag: ProtoTag, + enc_key_iv: Option<&[u8]>, // For fast mode +) { + // Set protocol tag + nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + + // For fast mode, copy the reversed enc_key_iv + if let Some(key_iv) = enc_key_iv { + let reversed: Vec = key_iv.iter().rev().copied().collect(); + nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); + } +} + +/// Encrypt the outgoing nonce for Telegram +pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { + // Derive encryption key from the nonce itself + let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let enc_key = sha256(key_iv); + let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); + + let mut encryptor = AesCtr::new(&enc_key, enc_iv); + + // Only encrypt from PROTO_TAG_POS onwards + let mut result = nonce.to_vec(); + let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); + result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_valid_nonce() { + // Valid nonce + let mut valid = [0x42u8; HANDSHAKE_LEN]; + valid[4..8].copy_from_slice(&[1, 2, 3, 4]); + assert!(is_valid_nonce(&valid)); + + // Invalid: starts with 0xef + let mut invalid = [0x00u8; HANDSHAKE_LEN]; + invalid[0] = 0xef; + assert!(!is_valid_nonce(&invalid)); + + // Invalid: starts with HEAD + let mut invalid = [0x00u8; HANDSHAKE_LEN]; + invalid[..4].copy_from_slice(b"HEAD"); + assert!(!is_valid_nonce(&invalid)); + + // Invalid: bytes 4-7 are zeros + let mut invalid = [0x42u8; HANDSHAKE_LEN]; + invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); + assert!(!is_valid_nonce(&invalid)); + } + + #[test] + fn test_generate_nonce() { + let mut counter = 0u8; + let nonce = generate_nonce(|n| { + counter = counter.wrapping_add(1); + vec![counter; n] + }); + + assert!(is_valid_nonce(&nonce)); + assert_eq!(nonce.len(), HANDSHAKE_LEN); + } +} \ No newline at end of file diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs new file mode 100644 index 0000000..f050d9f --- /dev/null +++ b/src/protocol/tls.rs @@ -0,0 +1,244 @@ +//! Fake TLS 1.3 Handshake + +use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; +use crate::error::{ProxyError, Result}; +use super::constants::*; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// TLS handshake digest length +pub const TLS_DIGEST_LEN: usize = 32; +/// Position of digest in TLS ClientHello +pub const TLS_DIGEST_POS: usize = 11; +/// Length to store for replay protection (first 16 bytes of digest) +pub const TLS_DIGEST_HALF_LEN: usize = 16; + +/// Time skew limits for anti-replay (in seconds) +pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before +pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after + +/// Result of validating TLS handshake +#[derive(Debug)] +pub struct TlsValidation { + /// Username that validated + pub user: String, + /// Session ID from ClientHello + pub session_id: Vec, + /// Client digest for response generation + pub digest: [u8; TLS_DIGEST_LEN], + /// Timestamp extracted from digest + pub timestamp: u32, +} + +/// Validate TLS ClientHello against user secrets +pub fn validate_tls_handshake( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, +) -> Option { + if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { + return None; + } + + // Extract digest + let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .try_into() + .ok()?; + + // Extract session ID + let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; + let session_id_len = handshake.get(session_id_len_pos).copied()? as usize; + let session_id_start = session_id_len_pos + 1; + + if handshake.len() < session_id_start + session_id_len { + return None; + } + + let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec(); + + // Build message for HMAC (with zeroed digest) + let mut msg = handshake.to_vec(); + msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + // Get current time + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + for (user, secret) in secrets { + let computed = sha256_hmac(secret, &msg); + + // XOR digests + let xored: Vec = digest.iter() + .zip(computed.iter()) + .map(|(a, b)| a ^ b) + .collect(); + + // Check that first 28 bytes are zeros (timestamp in last 4) + if !xored[..28].iter().all(|&b| b == 0) { + continue; + } + + // Extract timestamp + let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap()); + let time_diff = now - timestamp as i64; + + // Check time skew + if !ignore_time_skew { + // Allow very small timestamps (boot time instead of unix time) + let is_boot_time = timestamp < 60 * 60 * 24 * 1000; + + if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) { + continue; + } + } + + return Some(TlsValidation { + user: user.clone(), + session_id, + digest, + timestamp, + }); + } + + None +} + +/// Generate a fake X25519 public key for TLS +/// This generates a value that looks like a valid X25519 key +pub fn gen_fake_x25519_key() -> [u8; 32] { + // For simplicity, just generate random 32 bytes + // In real X25519, this would be a point on the curve + let bytes = SECURE_RANDOM.bytes(32); + bytes.try_into().unwrap() +} + +/// Build TLS ServerHello response +pub fn build_server_hello( + secret: &[u8], + client_digest: &[u8; TLS_DIGEST_LEN], + session_id: &[u8], + fake_cert_len: usize, +) -> Vec { + let x25519_key = gen_fake_x25519_key(); + + // TLS extensions + let mut extensions = Vec::new(); + extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder + extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension + extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve + extensions.extend_from_slice(&x25519_key); + extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions + + // ServerHello body + let mut srv_hello = Vec::new(); + srv_hello.extend_from_slice(&TLS_VERSION); + srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest + srv_hello.push(session_id.len() as u8); + srv_hello.extend_from_slice(session_id); + srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 + srv_hello.push(0x00); // No compression + srv_hello.extend_from_slice(&extensions); + + // Build complete packet + let mut hello_pkt = Vec::new(); + + // ServerHello record + hello_pkt.push(TLS_RECORD_HANDSHAKE); + hello_pkt.extend_from_slice(&TLS_VERSION); + hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes()); + hello_pkt.push(0x02); // ServerHello message type + let len_bytes = (srv_hello.len() as u32).to_be_bytes(); + hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length + hello_pkt.extend_from_slice(&srv_hello); + + // Change Cipher Spec record + hello_pkt.extend_from_slice(&[ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], TLS_VERSION[1], + 0x00, 0x01, 0x01 + ]); + + // Application Data record (fake certificate) + let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); + hello_pkt.push(TLS_RECORD_APPLICATION); + hello_pkt.extend_from_slice(&TLS_VERSION); + hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes()); + hello_pkt.extend_from_slice(&fake_cert); + + // Compute HMAC for the response + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len()); + hmac_input.extend_from_slice(client_digest); + hmac_input.extend_from_slice(&hello_pkt); + let response_digest = sha256_hmac(secret, &hmac_input); + + // Insert computed digest + // Position: after record header (5) + message type/length (4) + version (2) = 11 + hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&response_digest); + + hello_pkt +} + +/// Check if bytes look like a TLS ClientHello +pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { + if first_bytes.len() < 3 { + return false; + } + + // TLS record header: 0x16 0x03 0x01 + first_bytes[0] == TLS_RECORD_HANDSHAKE + && first_bytes[1] == 0x03 + && first_bytes[2] == 0x01 +} + +/// Parse TLS record header, returns (record_type, length) +pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { + let record_type = header[0]; + let version = [header[1], header[2]]; + + // We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3 + if version != [0x03, 0x01] && version != TLS_VERSION { + return None; + } + + let length = u16::from_be_bytes([header[3], header[4]]); + Some((record_type, length)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_tls_handshake() { + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data + assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version + assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short + } + + #[test] + fn test_parse_tls_record_header() { + let header = [0x16, 0x03, 0x01, 0x02, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_HANDSHAKE); + assert_eq!(result.1, 512); + + let header = [0x17, 0x03, 0x03, 0x40, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_APPLICATION); + assert_eq!(result.1, 16384); + } + + #[test] + fn test_gen_fake_x25519_key() { + let key1 = gen_fake_x25519_key(); + let key2 = gen_fake_x25519_key(); + + assert_eq!(key1.len(), 32); + assert_eq!(key2.len(), 32); + assert_ne!(key1, key2); // Should be random + } +} \ No newline at end of file diff --git a/src/proxy/client.rs b/src/proxy/client.rs new file mode 100644 index 0000000..6af001c --- /dev/null +++ b/src/proxy/client.rs @@ -0,0 +1,378 @@ +//! Client Handler + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; +use tokio::time::timeout; +use tracing::{debug, info, warn, error, trace}; + +use crate::config::ProxyConfig; +use crate::error::{ProxyError, Result, HandshakeResult}; +use crate::protocol::constants::*; +use crate::protocol::tls; +use crate::stats::{Stats, ReplayChecker}; +use crate::transport::{ConnectionPool, configure_client_socket}; +use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; +use crate::crypto::AesCtr; + +use super::handshake::{ + handle_tls_handshake, handle_mtproto_handshake, + HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, +}; +use super::relay::relay_bidirectional; +use super::masking::handle_bad_client; + +/// Client connection handler +pub struct ClientHandler { + config: Arc, + stats: Arc, + replay_checker: Arc, + pool: Arc, +} + +impl ClientHandler { + /// Create new client handler + pub fn new( + config: Arc, + stats: Arc, + replay_checker: Arc, + pool: Arc, + ) -> Self { + Self { + config, + stats, + replay_checker, + pool, + } + } + + /// Handle a client connection + pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) { + self.stats.increment_connects_all(); + + debug!(peer = %peer, "New connection"); + + // Configure socket + if let Err(e) = configure_client_socket( + &stream, + self.config.client_keepalive, + self.config.client_ack_timeout, + ) { + debug!(peer = %peer, error = %e, "Failed to configure client socket"); + } + + // Perform handshake with timeout + let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); + + let result = timeout( + handshake_timeout, + self.do_handshake(stream, peer) + ).await; + + match result { + Ok(Ok(())) => { + debug!(peer = %peer, "Connection handled successfully"); + } + Ok(Err(e)) => { + debug!(peer = %peer, error = %e, "Handshake failed"); + } + Err(_) => { + self.stats.increment_handshake_timeouts(); + debug!(peer = %peer, "Handshake timeout"); + } + } + } + + /// Perform handshake and relay + async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> { + // Read first bytes to determine handshake type + let mut first_bytes = [0u8; 5]; + stream.read_exact(&mut first_bytes).await?; + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + + debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); + + if is_tls { + self.handle_tls_client(stream, peer, first_bytes).await + } else { + self.handle_direct_client(stream, peer, first_bytes).await + } + } + + /// Handle TLS-wrapped client + async fn handle_tls_client( + &self, + mut stream: TcpStream, + peer: SocketAddr, + first_bytes: [u8; 5], + ) -> Result<()> { + // Read TLS handshake length + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); + + if tls_len < 512 { + debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + self.stats.increment_connects_bad(); + handle_bad_client(stream, &first_bytes, &self.config).await; + return Ok(()); + } + + // Read full TLS handshake + let mut handshake = vec![0u8; 5 + tls_len]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + // Split stream for reading/writing + let (read_half, write_half) = stream.into_split(); + + // Handle TLS handshake + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + &handshake, + read_half, + write_half, + peer, + &self.config, + &self.replay_checker, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient => { + self.stats.increment_connects_bad(); + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + // Read MTProto handshake through TLS + debug!(peer = %peer, "Reading MTProto handshake through TLS"); + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() + .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + + // Handle MTProto handshake + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &mtproto_handshake, + tls_reader, + tls_writer, + peer, + &self.config, + &self.replay_checker, + true, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient => { + self.stats.increment_connects_bad(); + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + // Handle authenticated client + self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await + } + + /// Handle direct (non-TLS) client + async fn handle_direct_client( + &self, + mut stream: TcpStream, + peer: SocketAddr, + first_bytes: [u8; 5], + ) -> Result<()> { + // Check if non-TLS modes are enabled + if !self.config.modes.classic && !self.config.modes.secure { + debug!(peer = %peer, "Non-TLS modes disabled"); + self.stats.increment_connects_bad(); + handle_bad_client(stream, &first_bytes, &self.config).await; + return Ok(()); + } + + // Read rest of handshake + let mut handshake = [0u8; HANDSHAKE_LEN]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + // Split stream + let (read_half, write_half) = stream.into_split(); + + // Handle MTProto handshake + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &handshake, + read_half, + write_half, + peer, + &self.config, + &self.replay_checker, + false, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient => { + self.stats.increment_connects_bad(); + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await + } + + /// Handle authenticated client - connect to Telegram and relay + async fn handle_authenticated_inner( + &self, + client_reader: CryptoReader, + client_writer: CryptoWriter, + success: HandshakeSuccess, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + { + let user = &success.user; + + // Check user limits + if let Err(e) = self.check_user_limits(user) { + warn!(user = %user, error = %e, "User limit exceeded"); + return Err(e); + } + + // Get datacenter address + let dc_addr = self.get_dc_addr(success.dc_idx)?; + + info!( + user = %user, + peer = %success.peer, + dc = success.dc_idx, + dc_addr = %dc_addr, + proto = ?success.proto_tag, + fast_mode = self.config.fast_mode, + "Connecting to Telegram" + ); + + // Connect to Telegram + let tg_stream = self.pool.get(dc_addr).await?; + + debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); + + // Perform Telegram handshake and get crypto streams + let (tg_reader, tg_writer) = self.do_tg_handshake( + tg_stream, + &success, + ).await?; + + debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); + + // Update stats + self.stats.increment_user_connects(user); + self.stats.increment_user_curr_connects(user); + + // Relay traffic - передаём Arc::clone(&self.stats) + let relay_result = relay_bidirectional( + client_reader, + client_writer, + tg_reader, + tg_writer, + user, + Arc::clone(&self.stats), + ).await; + + // Update stats + self.stats.decrement_user_curr_connects(user); + + match &relay_result { + Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), + Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"), + } + + relay_result + } + + /// Check user limits (expiration, connection count, data quota) + fn check_user_limits(&self, user: &str) -> Result<()> { + // Check expiration + if let Some(expiration) = self.config.user_expirations.get(user) { + if chrono::Utc::now() > *expiration { + return Err(ProxyError::UserExpired { user: user.to_string() }); + } + } + + // Check connection limit + if let Some(limit) = self.config.user_max_tcp_conns.get(user) { + let current = self.stats.get_user_curr_connects(user); + if current >= *limit as u64 { + return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); + } + } + + // Check data quota + if let Some(quota) = self.config.user_data_quota.get(user) { + let used = self.stats.get_user_total_octets(user); + if used >= *quota { + return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); + } + } + + Ok(()) + } + + /// Get datacenter address by index + fn get_dc_addr(&self, dc_idx: i16) -> Result { + let idx = (dc_idx.abs() - 1) as usize; + + let datacenters = if self.config.prefer_ipv6 { + &*TG_DATACENTERS_V6 + } else { + &*TG_DATACENTERS_V4 + }; + + datacenters.get(idx) + .map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT)) + .ok_or_else(|| ProxyError::InvalidHandshake( + format!("Invalid DC index: {}", dc_idx) + )) + } + + /// Perform handshake with Telegram server + /// Returns crypto reader and writer for TG connection + async fn do_tg_handshake( + &self, + mut stream: TcpStream, + success: &HandshakeSuccess, + ) -> Result<(CryptoReader, CryptoWriter)> { + // Generate nonce with keys for TG + let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( + success.proto_tag, + &success.dec_key, // Client's dec key + success.dec_iv, + self.config.fast_mode, + ); + + // Encrypt nonce + let encrypted_nonce = encrypt_tg_nonce(&nonce); + + debug!( + peer = %success.peer, + nonce_head = %hex::encode(&nonce[..16]), + encrypted_head = %hex::encode(&encrypted_nonce[..16]), + "Sending nonce to Telegram" + ); + + // Send to Telegram + stream.write_all(&encrypted_nonce).await?; + stream.flush().await?; + + debug!(peer = %success.peer, "Nonce sent to Telegram"); + + // Split stream and wrap with crypto + let (read_half, write_half) = stream.into_split(); + + let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); + let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); + + let tg_reader = CryptoReader::new(read_half, decryptor); + let tg_writer = CryptoWriter::new(write_half, encryptor); + + Ok((tg_reader, tg_writer)) + } +} \ No newline at end of file diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs new file mode 100644 index 0000000..19e8baa --- /dev/null +++ b/src/proxy/handshake.rs @@ -0,0 +1,411 @@ +//! MTProto Handshake Magics + +use std::net::SocketAddr; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tracing::{debug, warn, trace, info}; + +use crate::crypto::{sha256, AesCtr}; +use crate::crypto::random::SECURE_RANDOM; +use crate::protocol::constants::*; +use crate::protocol::tls; +use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; +use crate::error::{ProxyError, HandshakeResult}; +use crate::stats::ReplayChecker; +use crate::config::ProxyConfig; + +/// Result of successful handshake +#[derive(Debug, Clone)] +pub struct HandshakeSuccess { + /// Authenticated user name + pub user: String, + /// Target datacenter index + pub dc_idx: i16, + /// Protocol variant (abridged/intermediate/secure) + pub proto_tag: ProtoTag, + /// Decryption key and IV (for reading from client) + pub dec_key: [u8; 32], + pub dec_iv: u128, + /// Encryption key and IV (for writing to client) + pub enc_key: [u8; 32], + pub enc_iv: u128, + /// Client address + pub peer: SocketAddr, + /// Whether TLS was used + pub is_tls: bool, +} + +/// Handle fake TLS handshake +pub async fn handle_tls_handshake( + handshake: &[u8], + reader: R, + mut writer: W, + peer: SocketAddr, + config: &ProxyConfig, + replay_checker: &ReplayChecker, +) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String)> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); + + // Check minimum length + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + debug!(peer = %peer, "TLS handshake too short"); + return HandshakeResult::BadClient; + } + + // Extract digest for replay check + let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; + let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; + + // Check for replay + if replay_checker.check_tls_digest(digest_half) { + warn!(peer = %peer, "TLS replay attack detected"); + return HandshakeResult::BadClient; + } + + // Build secrets list + let secrets: Vec<(String, Vec)> = config.users.iter() + .filter_map(|(name, hex)| { + hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) + }) + .collect(); + + debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users"); + + // Validate handshake + let validation = match tls::validate_tls_handshake( + handshake, + &secrets, + config.ignore_time_skew, + ) { + Some(v) => v, + None => { + debug!(peer = %peer, "TLS handshake validation failed - no matching user"); + return HandshakeResult::BadClient; + } + }; + + // Get secret for response + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { + Some((_, s)) => s, + None => return HandshakeResult::BadClient, + }; + + // Build and send response + let response = tls::build_server_hello( + secret, + &validation.digest, + &validation.session_id, + config.fake_cert_len, + ); + + debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); + + if let Err(e) = writer.write_all(&response).await { + return HandshakeResult::Error(ProxyError::Io(e)); + } + + if let Err(e) = writer.flush().await { + return HandshakeResult::Error(ProxyError::Io(e)); + } + + // Record for replay protection + replay_checker.add_tls_digest(digest_half); + + info!( + peer = %peer, + user = %validation.user, + "TLS handshake successful" + ); + + HandshakeResult::Success(( + FakeTlsReader::new(reader), + FakeTlsWriter::new(writer), + validation.user, + )) +} + +/// Handle MTProto obfuscation handshake +pub async fn handle_mtproto_handshake( + handshake: &[u8; HANDSHAKE_LEN], + reader: R, + writer: W, + peer: SocketAddr, + config: &ProxyConfig, + replay_checker: &ReplayChecker, + is_tls: bool, +) -> HandshakeResult<(CryptoReader, CryptoWriter, HandshakeSuccess)> +where + R: AsyncRead + Unpin + Send, + W: AsyncWrite + Unpin + Send, +{ + trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); + + // Extract prekey and IV + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + debug!( + peer = %peer, + dec_prekey_iv = %hex::encode(dec_prekey_iv), + "Extracted prekey+IV from handshake" + ); + + // Check for replay + if replay_checker.check_handshake(dec_prekey_iv) { + warn!(peer = %peer, "MTProto replay attack detected"); + return HandshakeResult::BadClient; + } + + // Reversed for encryption direction + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); + + // Try each user's secret + for (user, secret_hex) in &config.users { + let secret = match hex::decode(secret_hex) { + Ok(s) => s, + Err(_) => continue, + }; + + // Derive decryption key + let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; + let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + + // Decrypt handshake to check protocol tag + let mut decryptor = AesCtr::new(&dec_key, dec_iv); + let decrypted = decryptor.decrypt(handshake); + + trace!( + peer = %peer, + user = %user, + decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]), + "Decrypted handshake tail" + ); + + // Check protocol tag + let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] + .try_into() + .unwrap(); + + let proto_tag = match ProtoTag::from_bytes(tag_bytes) { + Some(tag) => tag, + None => { + trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag"); + continue; + } + }; + + debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag"); + + // Check if mode is enabled + let mode_ok = match proto_tag { + ProtoTag::Secure => { + if is_tls { config.modes.tls } else { config.modes.secure } + } + ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, + }; + + if !mode_ok { + debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); + continue; + } + + // Extract DC index + let dc_idx = i16::from_le_bytes( + decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() + ); + + // Derive encryption key + let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; + let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; + + let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + enc_key_input.extend_from_slice(enc_prekey); + enc_key_input.extend_from_slice(&secret); + let enc_key = sha256(&enc_key_input); + + let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + + // Record for replay protection + replay_checker.add_handshake(dec_prekey_iv); + + // Create new cipher instances + let decryptor = AesCtr::new(&dec_key, dec_iv); + let encryptor = AesCtr::new(&enc_key, enc_iv); + + let success = HandshakeSuccess { + user: user.clone(), + dc_idx, + proto_tag, + dec_key, + dec_iv, + enc_key, + enc_iv, + peer, + is_tls, + }; + + info!( + peer = %peer, + user = %user, + dc = dc_idx, + proto = ?proto_tag, + tls = is_tls, + "MTProto handshake successful" + ); + + return HandshakeResult::Success(( + CryptoReader::new(reader, decryptor), + CryptoWriter::new(writer, encryptor), + success, + )); + } + + debug!(peer = %peer, "MTProto handshake: no matching user found"); + HandshakeResult::BadClient +} + +/// Generate nonce for Telegram connection +/// +/// In FAST MODE: we use the same keys for TG as for client, but reversed. +/// This means: client's enc_key becomes TG's dec_key and vice versa. +pub fn generate_tg_nonce( + proto_tag: ProtoTag, + client_dec_key: &[u8; 32], + client_dec_iv: u128, + fast_mode: bool, +) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { + loop { + let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); + let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); + + // Check reserved patterns + if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { + continue; + } + + let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); + if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { + continue; + } + + let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); + if RESERVED_NONCE_CONTINUES.contains(&continue_four) { + continue; + } + + // Set protocol tag + nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + + // Fast mode: copy client's dec_key+iv (this becomes TG's enc direction) + // In fast mode, we make TG use the same keys as client but swapped: + // - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed) + // - What we encrypt TO TG = what we decrypt FROM client + if fast_mode { + // Put client's dec_key + dec_iv into nonce[8:56] + // This will be used by TG for encryption TO us + nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); + nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] + .copy_from_slice(&client_dec_iv.to_be_bytes()); + } + + // Now compute what keys WE will use for TG connection + // enc_key_iv = nonce[8:56] (for encrypting TO TG) + // dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG) + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + + let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); + let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + + let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); + let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + + debug!( + fast_mode = fast_mode, + tg_enc_key = %hex::encode(&tg_enc_key[..8]), + tg_dec_key = %hex::encode(&tg_dec_key[..8]), + "Generated TG nonce" + ); + + return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); + } +} + +/// Encrypt nonce for sending to Telegram +/// +/// Only the part from PROTO_TAG_POS onwards is encrypted. +/// The encryption key is derived from enc_key_iv in the nonce itself. +pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { + // enc_key_iv is at nonce[8:56] + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + + // Key for encrypting is just the first 32 bytes of enc_key_iv + let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); + let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + + let mut encryptor = AesCtr::new(&key, iv); + + // Encrypt the entire nonce first, then take only the encrypted tail + let encrypted_full = encryptor.encrypt(nonce); + + // Result: unencrypted head + encrypted tail + let mut result = nonce[..PROTO_TAG_POS].to_vec(); + result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); + + trace!( + original = %hex::encode(&nonce[PROTO_TAG_POS..]), + encrypted = %hex::encode(&result[PROTO_TAG_POS..]), + "Encrypted nonce tail" + ); + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_tg_nonce() { + let client_dec_key = [0x42u8; 32]; + let client_dec_iv = 12345u128; + + let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + + // Check length + assert_eq!(nonce.len(), HANDSHAKE_LEN); + + // Check proto tag is set + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); + assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); + } + + #[test] + fn test_encrypt_tg_nonce() { + let client_dec_key = [0x42u8; 32]; + let client_dec_iv = 12345u128; + + let (nonce, _, _, _, _) = + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + + let encrypted = encrypt_tg_nonce(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + + // First PROTO_TAG_POS bytes should be unchanged + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + + // Rest should be different (encrypted) + assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); + } +} \ No newline at end of file diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs new file mode 100644 index 0000000..485ae53 --- /dev/null +++ b/src/proxy/masking.rs @@ -0,0 +1,115 @@ +//! Masking - forward unrecognized traffic to mask host + +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::time::timeout; +use tracing::debug; +use crate::config::ProxyConfig; +use crate::transport::set_linger_zero; + +const MASK_TIMEOUT: Duration = Duration::from_secs(5); +const MASK_BUFFER_SIZE: usize = 8192; + +/// Handle a bad client by forwarding to mask host +pub async fn handle_bad_client( + mut client: TcpStream, + initial_data: &[u8], + config: &ProxyConfig, +) { + if !config.mask { + // Masking disabled, just consume data + consume_client_data(client).await; + return; + } + + let mask_host = config.mask_host.as_deref() + .unwrap_or(&config.tls_domain); + let mask_port = config.mask_port; + + debug!( + host = %mask_host, + port = mask_port, + "Forwarding bad client to mask host" + ); + + // Connect to mask host + let mask_addr = format!("{}:{}", mask_host, mask_port); + let connect_result = timeout( + MASK_TIMEOUT, + TcpStream::connect(&mask_addr) + ).await; + + let mut mask_stream = match connect_result { + Ok(Ok(s)) => s, + Ok(Err(e)) => { + debug!(error = %e, "Failed to connect to mask host"); + consume_client_data(client).await; + return; + } + Err(_) => { + debug!("Timeout connecting to mask host"); + consume_client_data(client).await; + return; + } + }; + + // Send initial data to mask host + if mask_stream.write_all(initial_data).await.is_err() { + return; + } + + // Relay traffic + let (mut client_read, mut client_write) = client.into_split(); + let (mut mask_read, mut mask_write) = mask_stream.into_split(); + + let c2m = tokio::spawn(async move { + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; + loop { + match client_read.read(&mut buf).await { + Ok(0) | Err(_) => { + let _ = mask_write.shutdown().await; + break; + } + Ok(n) => { + if mask_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + } + } + }); + + let m2c = tokio::spawn(async move { + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; + loop { + match mask_read.read(&mut buf).await { + Ok(0) | Err(_) => { + let _ = client_write.shutdown().await; + break; + } + Ok(n) => { + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + } + } + }); + + // Wait for either to complete + tokio::select! { + _ = c2m => {} + _ = m2c => {} + } +} + +/// Just consume all data from client without responding +async fn consume_client_data(mut client: TcpStream) { + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; + while let Ok(n) = client.read(&mut buf).await { + if n == 0 { + break; + } + } +} \ No newline at end of file diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs new file mode 100644 index 0000000..92dd373 --- /dev/null +++ b/src/proxy/mod.rs @@ -0,0 +1,11 @@ +//! Proxy Defs + +pub mod handshake; +pub mod client; +pub mod relay; +pub mod masking; + +pub use handshake::*; +pub use client::ClientHandler; +pub use relay::*; +pub use masking::*; \ No newline at end of file diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs new file mode 100644 index 0000000..c531a1a --- /dev/null +++ b/src/proxy/relay.rs @@ -0,0 +1,162 @@ +//! Bidirectional Relay + +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; +use tracing::{debug, trace, warn}; +use crate::error::Result; +use crate::stats::Stats; +use std::sync::atomic::{AtomicU64, Ordering}; + +const BUFFER_SIZE: usize = 65536; + +/// Relay data bidirectionally between client and server +pub async fn relay_bidirectional( + mut client_reader: CR, + mut client_writer: CW, + mut server_reader: SR, + mut server_writer: SW, + user: &str, + stats: Arc, +) -> Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + let user_c2s = user.to_string(); + let user_s2c = user.to_string(); + + // Используем Arc::clone вместо stats.clone() + let stats_c2s = Arc::clone(&stats); + let stats_s2c = Arc::clone(&stats); + + let c2s_bytes = Arc::new(AtomicU64::new(0)); + let s2c_bytes = Arc::new(AtomicU64::new(0)); + let c2s_bytes_clone = Arc::clone(&c2s_bytes); + let s2c_bytes_clone = Arc::clone(&s2c_bytes); + + // Client -> Server task + let c2s = tokio::spawn(async move { + let mut buf = vec![0u8; BUFFER_SIZE]; + let mut total_bytes = 0u64; + let mut msg_count = 0u64; + + loop { + match client_reader.read(&mut buf).await { + Ok(0) => { + debug!( + user = %user_c2s, + total_bytes = total_bytes, + msgs = msg_count, + "Client closed connection (C->S)" + ); + let _ = server_writer.shutdown().await; + break; + } + Ok(n) => { + total_bytes += n as u64; + msg_count += 1; + c2s_bytes_clone.store(total_bytes, Ordering::Relaxed); + + stats_c2s.add_user_octets_from(&user_c2s, n as u64); + stats_c2s.increment_user_msgs_from(&user_c2s); + + trace!( + user = %user_c2s, + bytes = n, + total = total_bytes, + data_preview = %hex::encode(&buf[..n.min(32)]), + "C->S data" + ); + + if let Err(e) = server_writer.write_all(&buf[..n]).await { + debug!(user = %user_c2s, error = %e, "Failed to write to server"); + break; + } + if let Err(e) = server_writer.flush().await { + debug!(user = %user_c2s, error = %e, "Failed to flush to server"); + break; + } + } + Err(e) => { + debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error"); + break; + } + } + } + }); + + // Server -> Client task + let s2c = tokio::spawn(async move { + let mut buf = vec![0u8; BUFFER_SIZE]; + let mut total_bytes = 0u64; + let mut msg_count = 0u64; + + loop { + match server_reader.read(&mut buf).await { + Ok(0) => { + debug!( + user = %user_s2c, + total_bytes = total_bytes, + msgs = msg_count, + "Server closed connection (S->C)" + ); + let _ = client_writer.shutdown().await; + break; + } + Ok(n) => { + total_bytes += n as u64; + msg_count += 1; + s2c_bytes_clone.store(total_bytes, Ordering::Relaxed); + + stats_s2c.add_user_octets_to(&user_s2c, n as u64); + stats_s2c.increment_user_msgs_to(&user_s2c); + + trace!( + user = %user_s2c, + bytes = n, + total = total_bytes, + data_preview = %hex::encode(&buf[..n.min(32)]), + "S->C data" + ); + + if let Err(e) = client_writer.write_all(&buf[..n]).await { + debug!(user = %user_s2c, error = %e, "Failed to write to client"); + break; + } + if let Err(e) = client_writer.flush().await { + debug!(user = %user_s2c, error = %e, "Failed to flush to client"); + break; + } + } + Err(e) => { + debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error"); + break; + } + } + } + }); + + // Wait for either direction to complete + tokio::select! { + result = c2s => { + if let Err(e) = result { + warn!(error = %e, "C->S task panicked"); + } + } + result = s2c => { + if let Err(e) = result { + warn!(error = %e, "S->C task panicked"); + } + } + } + + debug!( + c2s_bytes = c2s_bytes.load(Ordering::Relaxed), + s2c_bytes = s2c_bytes.load(Ordering::Relaxed), + "Relay finished" + ); + + Ok(()) +} \ No newline at end of file diff --git a/src/stats/mod.rs b/src/stats/mod.rs new file mode 100644 index 0000000..6ae5af2 --- /dev/null +++ b/src/stats/mod.rs @@ -0,0 +1,223 @@ +//! Statistics + +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Instant; +use dashmap::DashMap; +use parking_lot::RwLock; +use lru::LruCache; +use std::num::NonZeroUsize; + +/// Thread-safe statistics +#[derive(Default)] +pub struct Stats { + // Global counters + connects_all: AtomicU64, + connects_bad: AtomicU64, + handshake_timeouts: AtomicU64, + + // Per-user stats + user_stats: DashMap, + + // Start time + start_time: RwLock>, +} + +/// Per-user statistics +#[derive(Default)] +pub struct UserStats { + pub connects: AtomicU64, + pub curr_connects: AtomicU64, + pub octets_from_client: AtomicU64, + pub octets_to_client: AtomicU64, + pub msgs_from_client: AtomicU64, + pub msgs_to_client: AtomicU64, +} + +impl Stats { + pub fn new() -> Self { + let stats = Self::default(); + *stats.start_time.write() = Some(Instant::now()); + stats + } + + // Global stats + pub fn increment_connects_all(&self) { + self.connects_all.fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_connects_bad(&self) { + self.connects_bad.fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_handshake_timeouts(&self) { + self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); + } + + pub fn get_connects_all(&self) -> u64 { + self.connects_all.load(Ordering::Relaxed) + } + + pub fn get_connects_bad(&self) -> u64 { + self.connects_bad.load(Ordering::Relaxed) + } + + // User stats + pub fn increment_user_connects(&self, user: &str) { + self.user_stats + .entry(user.to_string()) + .or_default() + .connects + .fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_user_curr_connects(&self, user: &str) { + self.user_stats + .entry(user.to_string()) + .or_default() + .curr_connects + .fetch_add(1, Ordering::Relaxed); + } + + pub fn decrement_user_curr_connects(&self, user: &str) { + if let Some(stats) = self.user_stats.get(user) { + stats.curr_connects.fetch_sub(1, Ordering::Relaxed); + } + } + + pub fn get_user_curr_connects(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.curr_connects.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + pub fn add_user_octets_from(&self, user: &str, bytes: u64) { + self.user_stats + .entry(user.to_string()) + .or_default() + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + } + + pub fn add_user_octets_to(&self, user: &str, bytes: u64) { + self.user_stats + .entry(user.to_string()) + .or_default() + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + } + + pub fn increment_user_msgs_from(&self, user: &str) { + self.user_stats + .entry(user.to_string()) + .or_default() + .msgs_from_client + .fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_user_msgs_to(&self, user: &str) { + self.user_stats + .entry(user.to_string()) + .or_default() + .msgs_to_client + .fetch_add(1, Ordering::Relaxed); + } + + pub fn get_user_total_octets(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| { + s.octets_from_client.load(Ordering::Relaxed) + + s.octets_to_client.load(Ordering::Relaxed) + }) + .unwrap_or(0) + } + + pub fn uptime_secs(&self) -> f64 { + self.start_time.read() + .map(|t| t.elapsed().as_secs_f64()) + .unwrap_or(0.0) + } +} + +// Arc Hightech Stats :D + +/// Replay attack checker using LRU cache +pub struct ReplayChecker { + handshakes: RwLock, ()>>, + tls_digests: RwLock, ()>>, +} + +impl ReplayChecker { + pub fn new(capacity: usize) -> Self { + let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); + Self { + handshakes: RwLock::new(LruCache::new(cap)), + tls_digests: RwLock::new(LruCache::new(cap)), + } + } + + pub fn check_handshake(&self, data: &[u8]) -> bool { + self.handshakes.read().contains(&data.to_vec()) + } + + pub fn add_handshake(&self, data: &[u8]) { + self.handshakes.write().put(data.to_vec(), ()); + } + + pub fn check_tls_digest(&self, data: &[u8]) -> bool { + self.tls_digests.read().contains(&data.to_vec()) + } + + pub fn add_tls_digest(&self, data: &[u8]) { + self.tls_digests.write().put(data.to_vec(), ()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_stats_shared_counters() { + let stats = Arc::new(Stats::new()); + + // Симулируем использование из разных "задач" + let stats1 = Arc::clone(&stats); + let stats2 = Arc::clone(&stats); + + stats1.increment_connects_all(); + stats2.increment_connects_all(); + stats1.increment_connects_all(); + + // Все инкременты должны быть видны + assert_eq!(stats.get_connects_all(), 3); + } + + #[test] + fn test_user_stats_shared() { + let stats = Arc::new(Stats::new()); + + let stats1 = Arc::clone(&stats); + let stats2 = Arc::clone(&stats); + + stats1.add_user_octets_from("user1", 100); + stats2.add_user_octets_from("user1", 200); + stats1.add_user_octets_to("user1", 50); + + assert_eq!(stats.get_user_total_octets("user1"), 350); + } + + #[test] + fn test_concurrent_user_connects() { + let stats = Arc::new(Stats::new()); + + stats.increment_user_curr_connects("user1"); + stats.increment_user_curr_connects("user1"); + assert_eq!(stats.get_user_curr_connects("user1"), 2); + + stats.decrement_user_curr_connects("user1"); + assert_eq!(stats.get_user_curr_connects("user1"), 1); + } +} \ No newline at end of file diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs new file mode 100644 index 0000000..123dfa5 --- /dev/null +++ b/src/stream/crypto_stream.rs @@ -0,0 +1,474 @@ +//! Encrypted stream wrappers using AES-CTR + +use bytes::{Bytes, BytesMut, BufMut}; +use std::io::{Error, ErrorKind, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; +use crate::crypto::AesCtr; +use parking_lot::Mutex; + +/// Reader that decrypts data using AES-CTR +pub struct CryptoReader { + upstream: R, + decryptor: AesCtr, + buffer: BytesMut, +} + +impl CryptoReader { + /// Create new crypto reader + pub fn new(upstream: R, decryptor: AesCtr) -> Self { + Self { + upstream, + decryptor, + buffer: BytesMut::with_capacity(8192), + } + } + + /// Get reference to upstream + pub fn get_ref(&self) -> &R { + &self.upstream + } + + /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut R { + &mut self.upstream + } + + /// Consume and return upstream + pub fn into_inner(self) -> R { + self.upstream + } +} + +impl AsyncRead for CryptoReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + + if !this.buffer.is_empty() { + let to_copy = this.buffer.len().min(buf.remaining()); + buf.put_slice(&this.buffer.split_to(to_copy)); + return Poll::Ready(Ok(())); + } + + // Zero-copy Reader + let before = buf.filled().len(); + + match Pin::new(&mut this.upstream).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + let after = buf.filled().len(); + let bytes_read = after - before; + + if bytes_read > 0 { + // Decrypt in-place + let filled = buf.filled_mut(); + this.decryptor.apply(&mut filled[before..after]); + } + + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +impl CryptoReader { + /// Read and decrypt exactly n bytes with Async + pub async fn read_exact_decrypt(&mut self, n: usize) -> Result { + let mut result = BytesMut::with_capacity(n); + + if !self.buffer.is_empty() { + let to_take = self.buffer.len().min(n); + result.extend_from_slice(&self.buffer.split_to(to_take)); + } + + // Reread + while result.len() < n { + let mut temp = vec![0u8; n - result.len()]; + let read = self.upstream.read(&mut temp).await?; + + if read == 0 { + return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed")); + } + + // Decrypt + self.decryptor.apply(&mut temp[..read]); + result.extend_from_slice(&temp[..read]); + } + + Ok(result.freeze()) + } +} + +/// Writer that encrypts data using AES-CTR +pub struct CryptoWriter { + upstream: W, + encryptor: AesCtr, + pending: BytesMut, +} + +impl CryptoWriter { + pub fn new(upstream: W, encryptor: AesCtr) -> Self { + Self { + upstream, + encryptor, + pending: BytesMut::with_capacity(8192), + } + } + + pub fn get_ref(&self) -> &W { + &self.upstream + } + + pub fn get_mut(&mut self) -> &mut W { + &mut self.upstream + } + + pub fn into_inner(self) -> W { + self.upstream + } +} + +impl AsyncWrite for CryptoWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + + if !this.pending.is_empty() { + match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { + Poll::Ready(Ok(written)) => { + let _ = this.pending.split_to(written); + + if !this.pending.is_empty() { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + // Pending Null + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + // Encrypt + let mut encrypted = buf.to_vec(); + this.encryptor.apply(&mut encrypted); + + // Write Try + match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) { + Poll::Ready(Ok(written)) => { + if written < encrypted.len() { + // Partial write — сохраняем остаток в pending + this.pending.extend_from_slice(&encrypted[written..]); + } + Poll::Ready(Ok(buf.len())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => { + this.pending.extend_from_slice(&encrypted); + Poll::Ready(Ok(buf.len())) + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while !this.pending.is_empty() { + match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(Error::new( + ErrorKind::WriteZero, + "Failed to write pending data during flush", + ))); + } + Poll::Ready(Ok(written)) => { + let _ = this.pending.split_to(written); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + Pin::new(&mut this.upstream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while !this.pending.is_empty() { + match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) { + Poll::Ready(Ok(0)) => { + break; + } + Poll::Ready(Ok(written)) => { + let _ = this.pending.split_to(written); + } + Poll::Ready(Err(_)) => { + break; + } + Poll::Pending => return Poll::Pending, + } + } + + Pin::new(&mut this.upstream).poll_shutdown(cx) + } +} + +/// Passthrough stream for fast mode - no encryption/decryption +pub struct PassthroughStream { + inner: S, +} + +impl PassthroughStream { + pub fn new(inner: S) -> Self { + Self { inner } + } + + pub fn into_inner(self) -> S { + self.inner + } +} + +impl AsyncRead for PassthroughStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for PassthroughStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::VecDeque; + use std::pin::Pin; + use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable}; + use tokio::io::duplex; + + /// Mock writer + struct PartialWriter { + chunk_size: usize, + data: Vec, + write_count: usize, + } + + impl PartialWriter { + fn new(chunk_size: usize) -> Self { + Self { + chunk_size, + data: Vec::new(), + write_count: 0, + } + } + } + + impl AsyncWrite for PartialWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.write_count += 1; + let to_write = buf.len().min(self.chunk_size); + self.data.extend_from_slice(&buf[..to_write]); + Poll::Ready(Ok(to_write)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn noop_waker() -> Waker { + const VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(std::ptr::null(), &VTABLE), + |_| {}, + |_| {}, + |_| {}, + ); + unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } + } + + #[test] + fn test_crypto_writer_partial_write_correctness() { + let key = [0x42u8; 32]; + let iv = 12345u128; + + // 10-byte Writer + let mock_writer = PartialWriter::new(10); + let encryptor = AesCtr::new(&key, iv); + let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); + + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + // 25 byte + let original = b"Hello, this is test data!"; + + // First Write + let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); + assert!(matches!(result, Poll::Ready(Ok(25)))); + + // Flush before continue Pending + loop { + match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { + Poll::Ready(Ok(())) => break, + Poll::Ready(Err(e)) => panic!("Flush error: {}", e), + Poll::Pending => continue, + } + } + + // Write Check + let encrypted = &crypto_writer.upstream.data; + assert_eq!(encrypted.len(), 25); + + // Decrypt + Verify + let mut decryptor = AesCtr::new(&key, iv); + let mut decrypted = encrypted.clone(); + decryptor.apply(&mut decrypted); + + assert_eq!(&decrypted, original); + } + + #[test] + fn test_crypto_writer_multiple_partial_writes() { + let key = [0xAB; 32]; + let iv = 9999u128; + + let mock_writer = PartialWriter::new(3); + let encryptor = AesCtr::new(&key, iv); + let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); + + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let data1 = b"First"; + let data2 = b"Second"; + let data3 = b"Third"; + + Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap(); + // Flush + while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} + + Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap(); + while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} + + Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap(); + while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {} + + // Assemble + let mut expected = Vec::new(); + expected.extend_from_slice(data1); + expected.extend_from_slice(data2); + expected.extend_from_slice(data3); + + // Decrypt + let mut decryptor = AesCtr::new(&key, iv); + let mut decrypted = crypto_writer.upstream.data.clone(); + decryptor.apply(&mut decrypted); + + assert_eq!(decrypted, expected); + } + + #[tokio::test] + async fn test_crypto_stream_roundtrip() { + let key = [0u8; 32]; + let iv = 12345u128; + + let (client, server) = duplex(4096); + + let encryptor = AesCtr::new(&key, iv); + let decryptor = AesCtr::new(&key, iv); + + let mut writer = CryptoWriter::new(client, encryptor); + let mut reader = CryptoReader::new(server, decryptor); + + // Write + let original = b"Hello, encrypted world!"; + writer.write_all(original).await.unwrap(); + writer.flush().await.unwrap(); + + // Read + let mut buf = vec![0u8; original.len()]; + reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, original); + } + + #[tokio::test] + async fn test_crypto_stream_large_data() { + let key = [0x55u8; 32]; + let iv = 777u128; + + let (client, server) = duplex(1024); + + let encryptor = AesCtr::new(&key, iv); + let decryptor = AesCtr::new(&key, iv); + + let mut writer = CryptoWriter::new(client, encryptor); + let mut reader = CryptoReader::new(server, decryptor); + + // Hugeload + let original: Vec = (0..10000).map(|i| (i % 256) as u8).collect(); + + // Write + let write_data = original.clone(); + let write_handle = tokio::spawn(async move { + writer.write_all(&write_data).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + }); + + // Read + let mut received = Vec::new(); + let mut buf = vec![0u8; 1024]; + loop { + match reader.read(&mut buf).await { + Ok(0) => break, + Ok(n) => received.extend_from_slice(&buf[..n]), + Err(e) => panic!("Read error: {}", e), + } + } + + write_handle.await.unwrap(); + + assert_eq!(received, original); + } +} \ No newline at end of file diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs new file mode 100644 index 0000000..9e62c8d --- /dev/null +++ b/src/stream/frame_stream.rs @@ -0,0 +1,585 @@ +//! MTProto frame stream wrappers + +use bytes::{Bytes, BytesMut}; +use std::io::{Error, ErrorKind, Result}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; +use crate::protocol::constants::*; +use crate::crypto::crc32; +use crate::crypto::random::SECURE_RANDOM; +use super::traits::{FrameMeta, LayeredStream}; + +// ============= Abridged (Compact) Frame ============= + +/// Reader for abridged MTProto framing +pub struct AbridgedFrameReader { + upstream: R, +} + +impl AbridgedFrameReader { + pub fn new(upstream: R) -> Self { + Self { upstream } + } +} + +impl AbridgedFrameReader { + /// Read a frame and return (data, metadata) + pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> { + let mut meta = FrameMeta::new(); + + // Read length byte + let mut len_byte = [0u8]; + self.upstream.read_exact(&mut len_byte).await?; + + let mut len = len_byte[0] as usize; + + // Check QuickACK flag (high bit) + if len >= 0x80 { + meta.quickack = true; + len -= 0x80; + } + + // Extended length (3 bytes) + if len == 0x7f { + let mut len_bytes = [0u8; 3]; + self.upstream.read_exact(&mut len_bytes).await?; + len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize; + } + + // Length is in 4-byte words + let byte_len = len * 4; + + // Read data + let mut data = vec![0u8; byte_len]; + self.upstream.read_exact(&mut data).await?; + + Ok((Bytes::from(data), meta)) + } +} + +impl LayeredStream for AbridgedFrameReader { + fn upstream(&self) -> &R { &self.upstream } + fn upstream_mut(&mut self) -> &mut R { &mut self.upstream } + fn into_upstream(self) -> R { self.upstream } +} + +/// Writer for abridged MTProto framing +pub struct AbridgedFrameWriter { + upstream: W, +} + +impl AbridgedFrameWriter { + pub fn new(upstream: W) -> Self { + Self { upstream } + } +} + +impl AbridgedFrameWriter { + /// Write a frame + pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> { + if data.len() % 4 != 0 { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()), + )); + } + + // Simple ACK: send reversed data + if meta.simple_ack { + let reversed: Vec = data.iter().rev().copied().collect(); + self.upstream.write_all(&reversed).await?; + return Ok(()); + } + + let len_div_4 = data.len() / 4; + + if len_div_4 < 0x7f { + // Short length (1 byte) + self.upstream.write_all(&[len_div_4 as u8]).await?; + } else if len_div_4 < (1 << 24) { + // Long length (4 bytes: 0x7f + 3 bytes) + let mut header = [0x7f, 0, 0, 0]; + header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]); + self.upstream.write_all(&header).await?; + } else { + return Err(Error::new( + ErrorKind::InvalidInput, + format!("Frame too large: {} bytes", data.len()), + )); + } + + self.upstream.write_all(data).await?; + Ok(()) + } + + pub async fn flush(&mut self) -> Result<()> { + self.upstream.flush().await + } +} + +impl LayeredStream for AbridgedFrameWriter { + fn upstream(&self) -> &W { &self.upstream } + fn upstream_mut(&mut self) -> &mut W { &mut self.upstream } + fn into_upstream(self) -> W { self.upstream } +} + +// ============= Intermediate Frame ============= + +/// Reader for intermediate MTProto framing +pub struct IntermediateFrameReader { + upstream: R, +} + +impl IntermediateFrameReader { + pub fn new(upstream: R) -> Self { + Self { upstream } + } +} + +impl IntermediateFrameReader { + pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> { + let mut meta = FrameMeta::new(); + + // Read 4-byte length + let mut len_bytes = [0u8; 4]; + self.upstream.read_exact(&mut len_bytes).await?; + + let mut len = u32::from_le_bytes(len_bytes) as usize; + + // Check QuickACK flag (high bit) + if len > 0x80000000 { + meta.quickack = true; + len -= 0x80000000; + } + + // Read data + let mut data = vec![0u8; len]; + self.upstream.read_exact(&mut data).await?; + + Ok((Bytes::from(data), meta)) + } +} + +impl LayeredStream for IntermediateFrameReader { + fn upstream(&self) -> &R { &self.upstream } + fn upstream_mut(&mut self) -> &mut R { &mut self.upstream } + fn into_upstream(self) -> R { self.upstream } +} + +/// Writer for intermediate MTProto framing +pub struct IntermediateFrameWriter { + upstream: W, +} + +impl IntermediateFrameWriter { + pub fn new(upstream: W) -> Self { + Self { upstream } + } +} + +impl IntermediateFrameWriter { + pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> { + if meta.simple_ack { + self.upstream.write_all(data).await?; + } else { + let len_bytes = (data.len() as u32).to_le_bytes(); + self.upstream.write_all(&len_bytes).await?; + self.upstream.write_all(data).await?; + } + Ok(()) + } + + pub async fn flush(&mut self) -> Result<()> { + self.upstream.flush().await + } +} + +impl LayeredStream for IntermediateFrameWriter { + fn upstream(&self) -> &W { &self.upstream } + fn upstream_mut(&mut self) -> &mut W { &mut self.upstream } + fn into_upstream(self) -> W { self.upstream } +} + +// ============= Secure Intermediate Frame ============= + +/// Reader for secure intermediate MTProto framing (with padding) +pub struct SecureIntermediateFrameReader { + upstream: R, +} + +impl SecureIntermediateFrameReader { + pub fn new(upstream: R) -> Self { + Self { upstream } + } +} + +impl SecureIntermediateFrameReader { + pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> { + let mut meta = FrameMeta::new(); + + // Read 4-byte length + let mut len_bytes = [0u8; 4]; + self.upstream.read_exact(&mut len_bytes).await?; + + let mut len = u32::from_le_bytes(len_bytes) as usize; + + // Check QuickACK flag + if len > 0x80000000 { + meta.quickack = true; + len -= 0x80000000; + } + + // Read data (including padding) + let mut data = vec![0u8; len]; + self.upstream.read_exact(&mut data).await?; + + // Strip padding (not aligned to 4) + if len % 4 != 0 { + let actual_len = len - (len % 4); + data.truncate(actual_len); + } + + Ok((Bytes::from(data), meta)) + } +} + +impl LayeredStream for SecureIntermediateFrameReader { + fn upstream(&self) -> &R { &self.upstream } + fn upstream_mut(&mut self) -> &mut R { &mut self.upstream } + fn into_upstream(self) -> R { self.upstream } +} + +/// Writer for secure intermediate MTProto framing +pub struct SecureIntermediateFrameWriter { + upstream: W, +} + +impl SecureIntermediateFrameWriter { + pub fn new(upstream: W) -> Self { + Self { upstream } + } +} + +impl SecureIntermediateFrameWriter { + pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> { + if meta.simple_ack { + self.upstream.write_all(data).await?; + return Ok(()); + } + + // Add random padding (0-3 bytes) + let padding_len = SECURE_RANDOM.range(4); + let padding = SECURE_RANDOM.bytes(padding_len); + + let total_len = data.len() + padding_len; + let len_bytes = (total_len as u32).to_le_bytes(); + + self.upstream.write_all(&len_bytes).await?; + self.upstream.write_all(data).await?; + self.upstream.write_all(&padding).await?; + + Ok(()) + } + + pub async fn flush(&mut self) -> Result<()> { + self.upstream.flush().await + } +} + +impl LayeredStream for SecureIntermediateFrameWriter { + fn upstream(&self) -> &W { &self.upstream } + fn upstream_mut(&mut self) -> &mut W { &mut self.upstream } + fn into_upstream(self) -> W { self.upstream } +} + +// ============= Full MTProto Frame (with CRC) ============= + +/// Reader for full MTProto framing with sequence numbers and CRC32 +pub struct MtprotoFrameReader { + upstream: R, + seq_no: i32, +} + +impl MtprotoFrameReader { + pub fn new(upstream: R, start_seq: i32) -> Self { + Self { upstream, seq_no: start_seq } + } +} + +impl MtprotoFrameReader { + pub async fn read_frame(&mut self) -> Result { + loop { + // Read length (4 bytes) + let mut len_bytes = [0u8; 4]; + self.upstream.read_exact(&mut len_bytes).await?; + let len = u32::from_le_bytes(len_bytes) as usize; + + // Skip padding-only messages + if len == 4 { + continue; + } + + // Validate length + if len < MIN_MSG_LEN || len > MAX_MSG_LEN || len % PADDING_FILLER.len() != 0 { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Invalid message length: {}", len), + )); + } + + // Read sequence number + let mut seq_bytes = [0u8; 4]; + self.upstream.read_exact(&mut seq_bytes).await?; + let msg_seq = i32::from_le_bytes(seq_bytes); + + if msg_seq != self.seq_no { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq), + )); + } + self.seq_no += 1; + + // Read data (length - 4 len - 4 seq - 4 crc = len - 12) + let data_len = len - 12; + let mut data = vec![0u8; data_len]; + self.upstream.read_exact(&mut data).await?; + + // Read and verify CRC32 + let mut crc_bytes = [0u8; 4]; + self.upstream.read_exact(&mut crc_bytes).await?; + let expected_crc = u32::from_le_bytes(crc_bytes); + + // Compute CRC over len + seq + data + let mut crc_input = Vec::with_capacity(8 + data_len); + crc_input.extend_from_slice(&len_bytes); + crc_input.extend_from_slice(&seq_bytes); + crc_input.extend_from_slice(&data); + let computed_crc = crc32(&crc_input); + + if computed_crc != expected_crc { + return Err(Error::new( + ErrorKind::InvalidData, + format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc), + )); + } + + return Ok(Bytes::from(data)); + } + } +} + +/// Writer for full MTProto framing +pub struct MtprotoFrameWriter { + upstream: W, + seq_no: i32, +} + +impl MtprotoFrameWriter { + pub fn new(upstream: W, start_seq: i32) -> Self { + Self { upstream, seq_no: start_seq } + } +} + +impl MtprotoFrameWriter { + pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> { + // Total length: 4 (len) + 4 (seq) + data + 4 (crc) + let len = msg.len() + 12; + + let len_bytes = (len as u32).to_le_bytes(); + let seq_bytes = self.seq_no.to_le_bytes(); + self.seq_no += 1; + + // Compute CRC + let mut crc_input = Vec::with_capacity(8 + msg.len()); + crc_input.extend_from_slice(&len_bytes); + crc_input.extend_from_slice(&seq_bytes); + crc_input.extend_from_slice(msg); + let checksum = crc32(&crc_input); + let crc_bytes = checksum.to_le_bytes(); + + // Calculate padding for CBC alignment + let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len(); + let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING; + let padding_count = padding_needed / PADDING_FILLER.len(); + + // Write everything + self.upstream.write_all(&len_bytes).await?; + self.upstream.write_all(&seq_bytes).await?; + self.upstream.write_all(msg).await?; + self.upstream.write_all(&crc_bytes).await?; + + for _ in 0..padding_count { + self.upstream.write_all(&PADDING_FILLER).await?; + } + + Ok(()) + } + + pub async fn flush(&mut self) -> Result<()> { + self.upstream.flush().await + } +} + +// ============= Frame Type Enum ============= + +/// Enum for different frame stream types +pub enum FrameReaderKind { + Abridged(AbridgedFrameReader), + Intermediate(IntermediateFrameReader), + SecureIntermediate(SecureIntermediateFrameReader), +} + +impl FrameReaderKind { + pub fn new(upstream: R, proto_tag: ProtoTag) -> Self { + match proto_tag { + ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)), + ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)), + ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)), + } + } + + pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> { + match self { + FrameReaderKind::Abridged(r) => r.read_frame().await, + FrameReaderKind::Intermediate(r) => r.read_frame().await, + FrameReaderKind::SecureIntermediate(r) => r.read_frame().await, + } + } +} + +pub enum FrameWriterKind { + Abridged(AbridgedFrameWriter), + Intermediate(IntermediateFrameWriter), + SecureIntermediate(SecureIntermediateFrameWriter), +} + +impl FrameWriterKind { + pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { + match proto_tag { + ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), + ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), + ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), + } + } + + pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> { + match self { + FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await, + FrameWriterKind::Intermediate(w) => w.write_frame(data, meta).await, + FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await, + } + } + + pub async fn flush(&mut self) -> Result<()> { + match self { + FrameWriterKind::Abridged(w) => w.flush().await, + FrameWriterKind::Intermediate(w) => w.flush().await, + FrameWriterKind::SecureIntermediate(w) => w.flush().await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::duplex; + + #[tokio::test] + async fn test_abridged_roundtrip() { + let (client, server) = duplex(1024); + + let mut writer = AbridgedFrameWriter::new(client); + let mut reader = AbridgedFrameReader::new(server); + + // Short frame + let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word + writer.write_frame(&data, &FrameMeta::new()).await.unwrap(); + writer.flush().await.unwrap(); + + let (received, _meta) = reader.read_frame().await.unwrap(); + assert_eq!(&received[..], &data[..]); + } + + #[tokio::test] + async fn test_abridged_long_frame() { + let (client, server) = duplex(65536); + + let mut writer = AbridgedFrameWriter::new(client); + let mut reader = AbridgedFrameReader::new(server); + + // Long frame (> 0x7f words = 508 bytes) + let data: Vec = (0..1000).map(|i| (i % 256) as u8).collect(); + let padded_len = (data.len() + 3) / 4 * 4; + let mut padded = data.clone(); + padded.resize(padded_len, 0); + + writer.write_frame(&padded, &FrameMeta::new()).await.unwrap(); + writer.flush().await.unwrap(); + + let (received, _meta) = reader.read_frame().await.unwrap(); + assert_eq!(&received[..], &padded[..]); + } + + #[tokio::test] + async fn test_intermediate_roundtrip() { + let (client, server) = duplex(1024); + + let mut writer = IntermediateFrameWriter::new(client); + let mut reader = IntermediateFrameReader::new(server); + + let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; + writer.write_frame(&data, &FrameMeta::new()).await.unwrap(); + writer.flush().await.unwrap(); + + let (received, _meta) = reader.read_frame().await.unwrap(); + assert_eq!(&received[..], &data[..]); + } + + #[tokio::test] + async fn test_secure_intermediate_padding() { + let (client, server) = duplex(1024); + + let mut writer = SecureIntermediateFrameWriter::new(client); + let mut reader = SecureIntermediateFrameReader::new(server); + + let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; + writer.write_frame(&data, &FrameMeta::new()).await.unwrap(); + writer.flush().await.unwrap(); + + let (received, _meta) = reader.read_frame().await.unwrap(); + // Received should have padding stripped to align to 4 + let expected_len = (data.len() / 4) * 4; + assert_eq!(received.len(), expected_len); + } + + #[tokio::test] + async fn test_mtproto_frame_roundtrip() { + let (client, server) = duplex(1024); + + let mut writer = MtprotoFrameWriter::new(client, 0); + let mut reader = MtprotoFrameReader::new(server, 0); + + // Message must be padded properly + let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING + writer.write_frame(&data).await.unwrap(); + writer.flush().await.unwrap(); + + let received = reader.read_frame().await.unwrap(); + assert_eq!(&received[..], &data[..]); + } + + #[tokio::test] + async fn test_frame_reader_kind() { + let (client, server) = duplex(1024); + + let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); + let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); + + let data = vec![1u8, 2, 3, 4]; + writer.write_frame(&data, &FrameMeta::new()).await.unwrap(); + writer.flush().await.unwrap(); + + let (received, _) = reader.read_frame().await.unwrap(); + assert_eq!(&received[..], &data[..]); + } +} \ No newline at end of file diff --git a/src/stream/mod.rs b/src/stream/mod.rs new file mode 100644 index 0000000..1d98a5e --- /dev/null +++ b/src/stream/mod.rs @@ -0,0 +1,10 @@ +//! Stream wrappers for MTProto protocol layers + +pub mod traits; +pub mod crypto_stream; +pub mod tls_stream; +pub mod frame_stream; + +pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream}; +pub use tls_stream::{FakeTlsReader, FakeTlsWriter}; +pub use frame_stream::*; \ No newline at end of file diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs new file mode 100644 index 0000000..fbe2f5e --- /dev/null +++ b/src/stream/tls_stream.rs @@ -0,0 +1,277 @@ +//! Fake TLS 1.3 stream wrappers + +use bytes::{Bytes, BytesMut}; +use std::io::{Error, ErrorKind, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; +use crate::protocol::constants::{ + TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, + MAX_TLS_CHUNK_SIZE, +}; +use parking_lot::Mutex; + +/// Reader that unwraps TLS 1.3 records +pub struct FakeTlsReader { + upstream: R, + buffer: BytesMut, + pending_read: Option, +} + +struct PendingTlsRead { + record_type: u8, + remaining: usize, +} + +impl FakeTlsReader { + /// Create new fake TLS reader + pub fn new(upstream: R) -> Self { + Self { + upstream, + buffer: BytesMut::with_capacity(16384), + pending_read: None, + } + } + + /// Get reference to upstream + pub fn get_ref(&self) -> &R { + &self.upstream + } + + /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut R { + &mut self.upstream + } + + /// Consume and return upstream + pub fn into_inner(self) -> R { + self.upstream + } +} + +impl FakeTlsReader { + /// Read exactly n bytes through TLS layer + pub async fn read_exact(&mut self, n: usize) -> Result { + while self.buffer.len() < n { + let data = self.read_tls_record().await?; + if data.is_empty() { + return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed")); + } + self.buffer.extend_from_slice(&data); + } + + Ok(self.buffer.split_to(n).freeze()) + } + + /// Read a single TLS record + async fn read_tls_record(&mut self) -> Result> { + loop { + // Read TLS record header (5 bytes) + let mut header = [0u8; 5]; + self.upstream.read_exact(&mut header).await?; + + let record_type = header[0]; + let version = [header[1], header[2]]; + let length = u16::from_be_bytes([header[3], header[4]]) as usize; + + // Validate version + if version != TLS_VERSION { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Invalid TLS version: {:02x?}", version), + )); + } + + // Read record body + let mut data = vec![0u8; length]; + self.upstream.read_exact(&mut data).await?; + + match record_type { + TLS_RECORD_CHANGE_CIPHER => continue, // Skip + TLS_RECORD_APPLICATION => return Ok(data), + _ => { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Unexpected TLS record type: 0x{:02x}", record_type), + )); + } + } + } + } +} + +impl AsyncRead for FakeTlsReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // Drain buffer first + if !self.buffer.is_empty() { + let to_copy = self.buffer.len().min(buf.remaining()); + buf.put_slice(&self.buffer.split_to(to_copy)); + return Poll::Ready(Ok(())); + } + + // We need to read a TLS record, but poll_read doesn't support async/await + // So we'll do a simplified version that reads header synchronously + + // Read header + let mut header = [0u8; 5]; + let mut header_buf = ReadBuf::new(&mut header); + + match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) { + Poll::Ready(Ok(())) => { + if header_buf.filled().len() < 5 { + // Need more data - store what we have and return pending + // For simplicity, we'll just return empty + return Poll::Ready(Ok(())); + } + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + + let record_type = header[0]; + let length = u16::from_be_bytes([header[3], header[4]]) as usize; + + if record_type == TLS_RECORD_CHANGE_CIPHER { + // Skip this record, try again + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + if record_type != TLS_RECORD_APPLICATION { + return Poll::Ready(Err(Error::new( + ErrorKind::InvalidData, + "Invalid TLS record type", + ))); + } + + // Read body + let mut body = vec![0u8; length]; + let mut body_buf = ReadBuf::new(&mut body); + + match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) { + Poll::Ready(Ok(())) => { + let filled = body_buf.filled(); + let to_copy = filled.len().min(buf.remaining()); + buf.put_slice(&filled[..to_copy]); + + if filled.len() > to_copy { + self.buffer.extend_from_slice(&filled[to_copy..]); + } + + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +/// Writer that wraps data in TLS 1.3 records +pub struct FakeTlsWriter { + upstream: W, +} + +impl FakeTlsWriter { + /// Create new fake TLS writer + pub fn new(upstream: W) -> Self { + Self { upstream } + } + + /// Get reference to upstream + pub fn get_ref(&self) -> &W { + &self.upstream + } + + /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut W { + &mut self.upstream + } + + /// Consume and return upstream + pub fn into_inner(self) -> W { + self.upstream + } +} + +impl AsyncWrite for FakeTlsWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Build TLS record + let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE); + let chunk = &buf[..chunk_size]; + + let mut record = Vec::with_capacity(5 + chunk_size); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.push((chunk_size >> 8) as u8); + record.push(chunk_size as u8); + record.extend_from_slice(chunk); + + match Pin::new(&mut self.upstream).poll_write(cx, &record) { + Poll::Ready(Ok(written)) => { + if written >= 5 { + Poll::Ready(Ok(written - 5)) + } else { + Poll::Ready(Ok(0)) + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.upstream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.upstream).poll_shutdown(cx) + } +} + +impl FakeTlsWriter { + /// Write all data wrapped in TLS records (async method) + pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { + for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) { + let header = [ + TLS_RECORD_APPLICATION, + TLS_VERSION[0], + TLS_VERSION[1], + (chunk.len() >> 8) as u8, + chunk.len() as u8, + ]; + + self.upstream.write_all(&header).await?; + self.upstream.write_all(chunk).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::duplex; + + #[tokio::test] + async fn test_tls_stream_roundtrip() { + let (client, server) = duplex(4096); + + let mut writer = FakeTlsWriter::new(client); + let mut reader = FakeTlsReader::new(server); + + let original = b"Hello, fake TLS!"; + writer.write_all_tls(original).await.unwrap(); + writer.flush().await.unwrap(); + + let received = reader.read_exact(original.len()).await.unwrap(); + assert_eq!(&received[..], original); + } +} \ No newline at end of file diff --git a/src/stream/traits.rs b/src/stream/traits.rs new file mode 100644 index 0000000..6419824 --- /dev/null +++ b/src/stream/traits.rs @@ -0,0 +1,113 @@ +//! Stream traits and common types + +use bytes::Bytes; +use std::io::Result; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// Extra metadata for frames +#[derive(Debug, Clone, Default)] +pub struct FrameMeta { + /// Quick ACK requested + pub quickack: bool, + /// This is a simple ACK message + pub simple_ack: bool, + /// Skip sending this frame + pub skip_send: bool, +} + +impl FrameMeta { + pub fn new() -> Self { + Self::default() + } + + pub fn with_quickack(mut self) -> Self { + self.quickack = true; + self + } + + pub fn with_simple_ack(mut self) -> Self { + self.simple_ack = true; + self + } +} + +/// Result of reading a frame +#[derive(Debug)] +pub enum ReadFrameResult { + /// Frame data with metadata + Frame(Bytes, FrameMeta), + /// Connection closed + Closed, +} + +/// Trait for streams that wrap another stream +pub trait LayeredStream { + /// Get reference to upstream + fn upstream(&self) -> &U; + + /// Get mutable reference to upstream + fn upstream_mut(&mut self) -> &mut U; + + /// Consume self and return upstream + fn into_upstream(self) -> U; +} + +/// A split read half of a stream +pub struct ReadHalf { + inner: R, +} + +impl ReadHalf { + pub fn new(inner: R) -> Self { + Self { inner } + } + + pub fn into_inner(self) -> R { + self.inner + } +} + +impl AsyncRead for ReadHalf { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +/// A split write half of a stream +pub struct WriteHalf { + inner: W, +} + +impl WriteHalf { + pub fn new(inner: W) -> Self { + Self { inner } + } + + pub fn into_inner(self) -> W { + self.inner + } +} + +impl AsyncWrite for WriteHalf { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} \ No newline at end of file diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..437b303 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,9 @@ +//! Transport layer: connection pooling, socket utilities, proxy protocol + +pub mod pool; +pub mod proxy_protocol; +pub mod socket; + +pub use pool::ConnectionPool; +pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; +pub use socket::*; \ No newline at end of file diff --git a/src/transport/pool.rs b/src/transport/pool.rs new file mode 100644 index 0000000..1daa998 --- /dev/null +++ b/src/transport/pool.rs @@ -0,0 +1,338 @@ +//! Connection Pool + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::time::timeout; +use parking_lot::RwLock; +use tracing::{debug, warn}; +use crate::error::{ProxyError, Result}; +use super::socket::configure_tcp_socket; + +/// A pooled connection with metadata +struct PooledConnection { + stream: TcpStream, + created_at: Instant, +} + +/// Internal pool state for a single endpoint +struct PoolInner { + /// Available connections + connections: Vec, + /// Number of connections being established + pending: usize, +} + +impl PoolInner { + fn new() -> Self { + Self { + connections: Vec::new(), + pending: 0, + } + } +} + +/// Connection pool configuration +#[derive(Debug, Clone)] +pub struct PoolConfig { + /// Maximum connections per endpoint + pub max_connections: usize, + /// Connection timeout + pub connect_timeout: Duration, + /// Maximum idle time before connection is dropped + pub max_idle_time: Duration, + /// Enable TCP keepalive + pub keepalive: bool, + /// Keepalive interval + pub keepalive_interval: Duration, +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_connections: 64, + connect_timeout: Duration::from_secs(10), + max_idle_time: Duration::from_secs(60), + keepalive: true, + keepalive_interval: Duration::from_secs(40), + } + } +} + +/// Thread-safe connection pool +pub struct ConnectionPool { + /// Per-endpoint pools + pools: RwLock>>>, + /// Configuration + config: PoolConfig, +} + +impl ConnectionPool { + /// Create new connection pool with default config + pub fn new() -> Self { + Self::with_config(PoolConfig::default()) + } + + /// Create connection pool with custom config + pub fn with_config(config: PoolConfig) -> Self { + Self { + pools: RwLock::new(HashMap::new()), + config, + } + } + + /// Get or create pool for an endpoint + fn get_or_create_pool(&self, addr: SocketAddr) -> Arc> { + // Fast path with read lock + { + let pools = self.pools.read(); + if let Some(pool) = pools.get(&addr) { + return Arc::clone(pool); + } + } + + // Slow path with write lock + let mut pools = self.pools.write(); + pools.entry(addr) + .or_insert_with(|| Arc::new(Mutex::new(PoolInner::new()))) + .clone() + } + + /// Get a connection to the specified address + pub async fn get(&self, addr: SocketAddr) -> Result { + let pool = self.get_or_create_pool(addr); + + // Try to get an existing connection + { + let mut inner = pool.lock().await; + + // Remove stale connections + let now = Instant::now(); + inner.connections.retain(|c| { + now.duration_since(c.created_at) < self.config.max_idle_time + }); + + // Try to find a usable connection + while let Some(conn) = inner.connections.pop() { + // Check if connection is still alive + if is_connection_alive(&conn.stream) { + debug!(addr = %addr, "Reusing pooled connection"); + return Ok(conn.stream); + } + debug!(addr = %addr, "Discarding dead pooled connection"); + } + + // Check if we can create a new connection + let total = inner.connections.len() + inner.pending; + if total >= self.config.max_connections { + return Err(ProxyError::ConnectionTimeout { + addr: addr.to_string() + }); + } + + inner.pending += 1; + } + + // Create new connection + debug!(addr = %addr, "Creating new connection"); + let result = self.create_connection(addr).await; + + // Decrement pending count + { + let mut inner = pool.lock().await; + inner.pending = inner.pending.saturating_sub(1); + } + + result + } + + /// Create a new connection to the address + async fn create_connection(&self, addr: SocketAddr) -> Result { + let connect_future = TcpStream::connect(addr); + + let stream = timeout(self.config.connect_timeout, connect_future) + .await + .map_err(|_| ProxyError::ConnectionTimeout { + addr: addr.to_string() + })? + .map_err(|e| { + if e.kind() == std::io::ErrorKind::ConnectionRefused { + ProxyError::ConnectionRefused { addr: addr.to_string() } + } else { + ProxyError::Io(e) + } + })?; + + // Configure socket + configure_tcp_socket( + &stream, + self.config.keepalive, + self.config.keepalive_interval, + )?; + + Ok(stream) + } + + /// Return a connection to the pool + pub async fn put(&self, addr: SocketAddr, stream: TcpStream) { + let pool = self.get_or_create_pool(addr); + let mut inner = pool.lock().await; + + if inner.connections.len() < self.config.max_connections { + inner.connections.push(PooledConnection { + stream, + created_at: Instant::now(), + }); + debug!(addr = %addr, pool_size = inner.connections.len(), "Returned connection to pool"); + } else { + debug!(addr = %addr, "Pool full, dropping connection"); + } + } + + /// Close all pooled connections + pub async fn close_all(&self) { + let pools = self.pools.read(); + for (addr, pool) in pools.iter() { + let mut inner = pool.lock().await; + let count = inner.connections.len(); + inner.connections.clear(); + debug!(addr = %addr, count = count, "Closed pooled connections"); + } + } + + /// Get pool statistics + pub async fn stats(&self) -> PoolStats { + let pools = self.pools.read(); + let mut total_connections = 0; + let mut total_pending = 0; + let mut endpoints = 0; + + for pool in pools.values() { + let inner = pool.lock().await; + total_connections += inner.connections.len(); + total_pending += inner.pending; + endpoints += 1; + } + + PoolStats { + endpoints, + total_connections, + total_pending, + } + } +} + +impl Default for ConnectionPool { + fn default() -> Self { + Self::new() + } +} + +/// Pool statistics +#[derive(Debug, Clone)] +pub struct PoolStats { + pub endpoints: usize, + pub total_connections: usize, + pub total_pending: usize, +} + +/// Check if a TCP connection is still alive (non-blocking) +fn is_connection_alive(stream: &TcpStream) -> bool { + // Try a non-blocking read to check connection state + let mut buf = [0u8; 1]; + match stream.try_read(&mut buf) { + Ok(0) => false, // Connection closed + Ok(_) => true, // Data available (shouldn't happen, but connection is alive) + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => true, // No data, but alive + Err(_) => false, // Some error, assume dead + } +} + +/// Connection pool with custom initialization +pub struct InitializingPool { + pool: ConnectionPool, + init_fn: F, +} + +impl InitializingPool +where + F: Fn(TcpStream, SocketAddr) -> Fut + Send + Sync, + Fut: std::future::Future> + Send, +{ + /// Create pool with initialization function + pub fn new(config: PoolConfig, init_fn: F) -> Self { + Self { + pool: ConnectionPool::with_config(config), + init_fn, + } + } + + /// Get an initialized connection + pub async fn get(&self, addr: SocketAddr) -> Result { + let stream = self.pool.get(addr).await?; + (self.init_fn)(stream, addr).await + } + + /// Return connection to pool + pub async fn put(&self, addr: SocketAddr, stream: TcpStream) { + self.pool.put(addr, stream).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::net::TcpListener; + + #[tokio::test] + async fn test_pool_basic() { + // Start a test server + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Accept connections in background + tokio::spawn(async move { + loop { + let _ = listener.accept().await; + } + }); + + let pool = ConnectionPool::new(); + + // Get a connection + let conn1 = pool.get(addr).await.unwrap(); + + // Return it to pool + pool.put(addr, conn1).await; + + // Get again (should reuse) + let _conn2 = pool.get(addr).await.unwrap(); + + let stats = pool.stats().await; + assert_eq!(stats.endpoints, 1); + } + + #[tokio::test] + async fn test_pool_connection_refused() { + let pool = ConnectionPool::with_config(PoolConfig { + connect_timeout: Duration::from_millis(100), + ..Default::default() + }); + + // Try to connect to a port that's not listening + let result = pool.get("127.0.0.1:1".parse().unwrap()).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_pool_stats() { + let pool = ConnectionPool::new(); + + let stats = pool.stats().await; + assert_eq!(stats.endpoints, 0); + assert_eq!(stats.total_connections, 0); + } +} \ No newline at end of file diff --git a/src/transport/proxy_protocol.rs b/src/transport/proxy_protocol.rs new file mode 100644 index 0000000..03c78d8 --- /dev/null +++ b/src/transport/proxy_protocol.rs @@ -0,0 +1,381 @@ +//! HAProxy PROXY protocol V1/V2 + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use crate::error::{ProxyError, Result}; + +/// PROXY protocol v1 signature +const PROXY_V1_SIGNATURE: &[u8] = b"PROXY "; + +/// PROXY protocol v2 signature +const PROXY_V2_SIGNATURE: &[u8] = &[ + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, + 0x51, 0x55, 0x49, 0x54, 0x0a +]; + +/// Minimum length for v1 detection +const PROXY_V1_MIN_LEN: usize = 6; + +/// Minimum length for v2 header +const PROXY_V2_MIN_LEN: usize = 16; + +/// Address families for v2 +mod address_family { + pub const UNSPEC: u8 = 0x0; + pub const INET: u8 = 0x1; + pub const INET6: u8 = 0x2; +} + +/// Information extracted from PROXY protocol header +#[derive(Debug, Clone)] +pub struct ProxyProtocolInfo { + /// Source (client) address + pub src_addr: SocketAddr, + /// Destination address (optional) + pub dst_addr: Option, + /// Protocol version used (1 or 2) + pub version: u8, +} + +impl ProxyProtocolInfo { + /// Create info with just source address + pub fn new(src_addr: SocketAddr) -> Self { + Self { + src_addr, + dst_addr: None, + version: 0, + } + } +} + +/// Parse PROXY protocol header from a stream +/// +/// Returns the parsed info or an error if the header is invalid. +/// The stream position is advanced past the header. +pub async fn parse_proxy_protocol( + reader: &mut R, + default_peer: SocketAddr, +) -> Result { + // Read enough bytes to detect version + let mut header = [0u8; PROXY_V2_MIN_LEN]; + reader.read_exact(&mut header[..PROXY_V1_MIN_LEN]).await + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + + // Check for v1 + if header[..PROXY_V1_MIN_LEN] == PROXY_V1_SIGNATURE[..] { + return parse_v1(reader, default_peer).await; + } + + // Read rest for v2 detection + reader.read_exact(&mut header[PROXY_V1_MIN_LEN..]).await + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + + // Check for v2 + if header[..12] == PROXY_V2_SIGNATURE[..] { + return parse_v2(reader, &header, default_peer).await; + } + + Err(ProxyError::InvalidProxyProtocol) +} + +/// Parse PROXY protocol v1 +async fn parse_v1( + reader: &mut R, + default_peer: SocketAddr, +) -> Result { + // Read until CRLF (max 107 bytes total for v1) + let mut line = Vec::with_capacity(128); + line.extend_from_slice(PROXY_V1_SIGNATURE); + + loop { + let mut byte = [0u8]; + reader.read_exact(&mut byte).await + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + line.push(byte[0]); + + if line.ends_with(b"\r\n") { + break; + } + + if line.len() > 256 { + return Err(ProxyError::InvalidProxyProtocol); + } + } + + // Parse the line: PROXY TCP4/TCP6/UNKNOWN src_ip dst_ip src_port dst_port + let line_str = std::str::from_utf8(&line[PROXY_V1_MIN_LEN..line.len() - 2]) + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + + let parts: Vec<&str> = line_str.split_whitespace().collect(); + + if parts.is_empty() { + return Err(ProxyError::InvalidProxyProtocol); + } + + match parts[0] { + "TCP4" | "TCP6" if parts.len() >= 5 => { + let src_ip: IpAddr = parts[1].parse() + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + let dst_ip: IpAddr = parts[2].parse() + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + let src_port: u16 = parts[3].parse() + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + let dst_port: u16 = parts[4].parse() + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + + Ok(ProxyProtocolInfo { + src_addr: SocketAddr::new(src_ip, src_port), + dst_addr: Some(SocketAddr::new(dst_ip, dst_port)), + version: 1, + }) + } + "UNKNOWN" => { + // UNKNOWN means no address info, use default + Ok(ProxyProtocolInfo { + src_addr: default_peer, + dst_addr: None, + version: 1, + }) + } + _ => Err(ProxyError::InvalidProxyProtocol), + } +} + +/// Parse PROXY protocol v2 +async fn parse_v2( + reader: &mut R, + header: &[u8; PROXY_V2_MIN_LEN], + default_peer: SocketAddr, +) -> Result { + let version_command = header[12]; + let version = version_command >> 4; + let command = version_command & 0x0f; + + // Must be version 2 + if version != 2 { + return Err(ProxyError::InvalidProxyProtocol); + } + + let family_protocol = header[13]; + let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize; + + // Read address data + let mut addr_data = vec![0u8; addr_len]; + if addr_len > 0 { + reader.read_exact(&mut addr_data).await + .map_err(|_| ProxyError::InvalidProxyProtocol)?; + } + + // LOCAL command (0x0) - use default peer + if command == 0 { + return Ok(ProxyProtocolInfo { + src_addr: default_peer, + dst_addr: None, + version: 2, + }); + } + + // PROXY command (0x1) - parse addresses + if command != 1 { + return Err(ProxyError::InvalidProxyProtocol); + } + + let family = family_protocol >> 4; + + match family { + address_family::INET if addr_len >= 12 => { + // IPv4: 4 + 4 + 2 + 2 = 12 bytes + let src_ip = Ipv4Addr::new( + addr_data[0], addr_data[1], + addr_data[2], addr_data[3] + ); + let dst_ip = Ipv4Addr::new( + addr_data[4], addr_data[5], + addr_data[6], addr_data[7] + ); + let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]); + let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]); + + Ok(ProxyProtocolInfo { + src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port), + dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)), + version: 2, + }) + } + address_family::INET6 if addr_len >= 36 => { + // IPv6: 16 + 16 + 2 + 2 = 36 bytes + let src_ip = Ipv6Addr::from( + <[u8; 16]>::try_from(&addr_data[0..16]).unwrap() + ); + let dst_ip = Ipv6Addr::from( + <[u8; 16]>::try_from(&addr_data[16..32]).unwrap() + ); + let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]); + let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]); + + Ok(ProxyProtocolInfo { + src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port), + dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)), + version: 2, + }) + } + address_family::UNSPEC => { + Ok(ProxyProtocolInfo { + src_addr: default_peer, + dst_addr: None, + version: 2, + }) + } + _ => Err(ProxyError::InvalidProxyProtocol), + } +} + +/// Builder for PROXY protocol v1 header +pub struct ProxyProtocolV1Builder { + family: &'static str, + src_addr: Option, + dst_addr: Option, +} + +impl ProxyProtocolV1Builder { + pub fn new() -> Self { + Self { + family: "UNKNOWN", + src_addr: None, + dst_addr: None, + } + } + + pub fn tcp4(mut self, src: SocketAddr, dst: SocketAddr) -> Self { + self.family = "TCP4"; + self.src_addr = Some(src); + self.dst_addr = Some(dst); + self + } + + pub fn tcp6(mut self, src: SocketAddr, dst: SocketAddr) -> Self { + self.family = "TCP6"; + self.src_addr = Some(src); + self.dst_addr = Some(dst); + self + } + + pub fn build(&self) -> Vec { + match (self.src_addr, self.dst_addr) { + (Some(src), Some(dst)) => { + format!( + "PROXY {} {} {} {} {}\r\n", + self.family, + src.ip(), + dst.ip(), + src.port(), + dst.port() + ).into_bytes() + } + _ => b"PROXY UNKNOWN\r\n".to_vec(), + } + } +} + +impl Default for ProxyProtocolV1Builder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[tokio::test] + async fn test_parse_v1_tcp4() { + let header = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n"; + let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]); + let default = "0.0.0.0:0".parse().unwrap(); + + // Simulate that we've already read the signature + let info = parse_v1(&mut cursor, default).await.unwrap(); + + assert_eq!(info.version, 1); + assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1"); + assert_eq!(info.src_addr.port(), 12345); + assert!(info.dst_addr.is_some()); + } + + #[tokio::test] + async fn test_parse_v1_unknown() { + let header = b"PROXY UNKNOWN\r\n"; + let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]); + let default: SocketAddr = "1.2.3.4:5678".parse().unwrap(); + + let info = parse_v1(&mut cursor, default).await.unwrap(); + + assert_eq!(info.version, 1); + assert_eq!(info.src_addr, default); + } + + #[tokio::test] + async fn test_parse_v2_tcp4() { + // v2 header for TCP4 + let mut header = [0u8; 16]; + header[..12].copy_from_slice(PROXY_V2_SIGNATURE); + header[12] = 0x21; // v2, PROXY command + header[13] = 0x11; // AF_INET, STREAM + header[14] = 0x00; + header[15] = 0x0c; // 12 bytes of address data + + let addr_data = [ + 192, 168, 1, 1, // src IP + 10, 0, 0, 1, // dst IP + 0x30, 0x39, // src port (12345) + 0x01, 0xbb, // dst port (443) + ]; + + let mut cursor = Cursor::new(addr_data.to_vec()); + let default = "0.0.0.0:0".parse().unwrap(); + + let info = parse_v2(&mut cursor, &header, default).await.unwrap(); + + assert_eq!(info.version, 2); + assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1"); + assert_eq!(info.src_addr.port(), 12345); + } + + #[tokio::test] + async fn test_parse_v2_local() { + let mut header = [0u8; 16]; + header[..12].copy_from_slice(PROXY_V2_SIGNATURE); + header[12] = 0x20; // v2, LOCAL command + header[13] = 0x00; + header[14] = 0x00; + header[15] = 0x00; // 0 bytes of address data + + let mut cursor = Cursor::new(Vec::new()); + let default: SocketAddr = "1.2.3.4:5678".parse().unwrap(); + + let info = parse_v2(&mut cursor, &header, default).await.unwrap(); + + assert_eq!(info.version, 2); + assert_eq!(info.src_addr, default); + } + + #[test] + fn test_v1_builder() { + let src: SocketAddr = "192.168.1.1:12345".parse().unwrap(); + let dst: SocketAddr = "10.0.0.1:443".parse().unwrap(); + + let header = ProxyProtocolV1Builder::new() + .tcp4(src, dst) + .build(); + + let expected = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n"; + assert_eq!(header, expected); + } + + #[test] + fn test_v1_builder_unknown() { + let header = ProxyProtocolV1Builder::new().build(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); + } +} \ No newline at end of file diff --git a/src/transport/socket.rs b/src/transport/socket.rs new file mode 100644 index 0000000..10c227a --- /dev/null +++ b/src/transport/socket.rs @@ -0,0 +1,230 @@ +//! TCP Socket Configuration + +use std::io::Result; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::TcpStream; +use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol}; +use tracing::debug; + +/// Configure TCP socket with recommended settings for proxy use +pub fn configure_tcp_socket( + stream: &TcpStream, + keepalive: bool, + keepalive_interval: Duration, +) -> Result<()> { + let socket = socket2::SockRef::from(stream); + + // Disable Nagle's algorithm for lower latency + socket.set_nodelay(true)?; + + // Set keepalive if enabled + if keepalive { + let keepalive = TcpKeepalive::new() + .with_time(keepalive_interval); + + // Platform-specific keepalive settings + #[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))] + let keepalive = keepalive.with_interval(keepalive_interval); + + socket.set_tcp_keepalive(&keepalive)?; + } + + // Set buffer sizes + set_buffer_sizes(&socket, 65536, 65536)?; + + Ok(()) +} + +/// Set socket buffer sizes +fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> { + // These may fail on some systems, so we ignore errors + let _ = socket.set_recv_buffer_size(recv); + let _ = socket.set_send_buffer_size(send); + Ok(()) +} + +/// Configure socket for accepting client connections +pub fn configure_client_socket( + stream: &TcpStream, + keepalive_secs: u64, + ack_timeout_secs: u64, +) -> Result<()> { + let socket = socket2::SockRef::from(stream); + + // Disable Nagle's algorithm + socket.set_nodelay(true)?; + + // Set keepalive + let keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(keepalive_secs)); + + #[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))] + let keepalive = keepalive.with_interval(Duration::from_secs(keepalive_secs)); + + socket.set_tcp_keepalive(&keepalive)?; + + // Set TCP user timeout (Linux only) + #[cfg(target_os = "linux")] + { + use std::os::unix::io::AsRawFd; + let fd = stream.as_raw_fd(); + let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int; + unsafe { + libc::setsockopt( + fd, + libc::IPPROTO_TCP, + libc::TCP_USER_TIMEOUT, + &timeout_ms as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ); + } + } + + Ok(()) +} + +/// Set socket to send RST on close (for masking) +pub fn set_linger_zero(stream: &TcpStream) -> Result<()> { + let socket = socket2::SockRef::from(stream); + socket.set_linger(Some(Duration::ZERO))?; + Ok(()) +} + +/// Create a new TCP socket for outgoing connections +pub fn create_outgoing_socket(addr: SocketAddr) -> Result { + let domain = if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }; + + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + // Set non-blocking + socket.set_nonblocking(true)?; + + // Disable Nagle + socket.set_nodelay(true)?; + + Ok(socket) +} + +/// Get local address of a socket +pub fn get_local_addr(stream: &TcpStream) -> Option { + stream.local_addr().ok() +} + +/// Get peer address of a socket +pub fn get_peer_addr(stream: &TcpStream) -> Option { + stream.peer_addr().ok() +} + +/// Check if address is IPv6 +pub fn is_ipv6(addr: &SocketAddr) -> bool { + addr.is_ipv6() +} + +/// Parse IPv4-mapped IPv6 address to IPv4 +pub fn normalize_ip(addr: SocketAddr) -> SocketAddr { + match addr { + SocketAddr::V6(v6) => { + if let Some(v4) = v6.ip().to_ipv4_mapped() { + SocketAddr::new(std::net::IpAddr::V4(v4), v6.port()) + } else { + addr + } + } + _ => addr, + } +} + +/// Socket options for server listening +#[derive(Debug, Clone)] +pub struct ListenOptions { + /// Enable SO_REUSEADDR + pub reuse_addr: bool, + /// Enable SO_REUSEPORT (Linux/BSD) + pub reuse_port: bool, + /// Backlog size + pub backlog: u32, + /// IPv6 only (disable dual-stack) + pub ipv6_only: bool, +} + +impl Default for ListenOptions { + fn default() -> Self { + Self { + reuse_addr: true, + reuse_port: true, + backlog: 1024, + ipv6_only: false, + } + } +} + +/// Create a listening socket with the specified options +pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result { + let domain = if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }; + + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + if options.reuse_addr { + socket.set_reuse_address(true)?; + } + + #[cfg(unix)] + if options.reuse_port { + socket.set_reuse_port(true)?; + } + + if addr.is_ipv6() && options.ipv6_only { + socket.set_only_v6(true)?; + } + + socket.set_nonblocking(true)?; + socket.bind(&addr.into())?; + socket.listen(options.backlog as i32)?; + + debug!(addr = %addr, "Created listening socket"); + + Ok(socket) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::net::TcpListener; + + #[tokio::test] + async fn test_configure_socket() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let stream = TcpStream::connect(addr).await.unwrap(); + configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap(); + } + + #[test] + fn test_normalize_ip() { + // IPv4 stays IPv4 + let v4: SocketAddr = "192.168.1.1:8080".parse().unwrap(); + assert_eq!(normalize_ip(v4), v4); + + // Pure IPv6 stays IPv6 + let v6: SocketAddr = "[::1]:8080".parse().unwrap(); + assert_eq!(normalize_ip(v6), v6); + } + + #[test] + fn test_listen_options_default() { + let opts = ListenOptions::default(); + assert!(opts.reuse_addr); + assert!(opts.reuse_port); + assert_eq!(opts.backlog, 1024); + } +} \ No newline at end of file diff --git a/src/util/ip.rs b/src/util/ip.rs new file mode 100644 index 0000000..fda108c --- /dev/null +++ b/src/util/ip.rs @@ -0,0 +1,118 @@ +//! IP Addr Detect + +use std::net::IpAddr; +use std::time::Duration; +use tracing::{debug, warn}; + +/// Detected IP addresses +#[derive(Debug, Clone, Default)] +pub struct IpInfo { + pub ipv4: Option, + pub ipv6: Option, +} + +impl IpInfo { + /// Check if any IP is detected + pub fn has_any(&self) -> bool { + self.ipv4.is_some() || self.ipv6.is_some() + } + + /// Get preferred IP (IPv6 if available and preferred) + pub fn preferred(&self, prefer_ipv6: bool) -> Option { + if prefer_ipv6 { + self.ipv6.or(self.ipv4) + } else { + self.ipv4.or(self.ipv6) + } + } +} + +/// URLs for IP detection +const IPV4_URLS: &[&str] = &[ + "http://v4.ident.me/", + "http://ipv4.icanhazip.com/", + "http://api.ipify.org/", +]; + +const IPV6_URLS: &[&str] = &[ + "http://v6.ident.me/", + "http://ipv6.icanhazip.com/", + "http://api6.ipify.org/", +]; + +/// Detect public IP addresses +pub async fn detect_ip() -> IpInfo { + let mut info = IpInfo::default(); + + // Detect IPv4 + for url in IPV4_URLS { + if let Some(ip) = fetch_ip(url).await { + if ip.is_ipv4() { + info.ipv4 = Some(ip); + debug!(ip = %ip, "Detected IPv4 address"); + break; + } + } + } + + // Detect IPv6 + for url in IPV6_URLS { + if let Some(ip) = fetch_ip(url).await { + if ip.is_ipv6() { + info.ipv6 = Some(ip); + debug!(ip = %ip, "Detected IPv6 address"); + break; + } + } + } + + if !info.has_any() { + warn!("Failed to detect public IP address"); + } + + info +} + +/// Fetch IP from URL +async fn fetch_ip(url: &str) -> Option { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(5)) + .build() + .ok()?; + + let response = client.get(url).send().await.ok()?; + let text = response.text().await.ok()?; + + text.trim().parse().ok() +} + +/// Synchronous IP detection (for startup) +pub fn detect_ip_sync() -> IpInfo { + tokio::runtime::Handle::current().block_on(detect_ip()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ip_info() { + let info = IpInfo::default(); + assert!(!info.has_any()); + + let info = IpInfo { + ipv4: Some("1.2.3.4".parse().unwrap()), + ipv6: None, + }; + assert!(info.has_any()); + assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap())); + assert_eq!(info.preferred(true), Some("1.2.3.4".parse().unwrap())); + + let info = IpInfo { + ipv4: Some("1.2.3.4".parse().unwrap()), + ipv6: Some("::1".parse().unwrap()), + }; + assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap())); + assert_eq!(info.preferred(true), Some("::1".parse().unwrap())); + } +} \ No newline at end of file diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..5d293d2 --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,7 @@ +//! Utils + +pub mod ip; +pub mod time; + +pub use ip::*; +pub use time::*; \ No newline at end of file diff --git a/src/util/time.rs b/src/util/time.rs new file mode 100644 index 0000000..7db1633 --- /dev/null +++ b/src/util/time.rs @@ -0,0 +1,76 @@ +//! Time Sync + +use std::time::Duration; +use chrono::{DateTime, Utc}; +use tracing::{debug, warn, error}; + +const TIME_SYNC_URL: &str = "https://core.telegram.org/getProxySecret"; +const MAX_TIME_SKEW_SECS: i64 = 30; + +/// Time sync result +#[derive(Debug, Clone)] +pub struct TimeSyncResult { + pub server_time: DateTime, + pub local_time: DateTime, + pub skew_secs: i64, + pub is_skewed: bool, +} + +/// Check time synchronization with Telegram servers +pub async fn check_time_sync() -> Option { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .ok()?; + + let response = client.get(TIME_SYNC_URL).send().await.ok()?; + + // Get Date header + let date_header = response.headers().get("date")?; + let date_str = date_header.to_str().ok()?; + + // Parse date + let server_time = DateTime::parse_from_rfc2822(date_str) + .ok()? + .with_timezone(&Utc); + + let local_time = Utc::now(); + let skew_secs = (local_time - server_time).num_seconds(); + let is_skewed = skew_secs.abs() > MAX_TIME_SKEW_SECS; + + let result = TimeSyncResult { + server_time, + local_time, + skew_secs, + is_skewed, + }; + + if is_skewed { + warn!( + server = %server_time, + local = %local_time, + skew = skew_secs, + "Time skew detected" + ); + } else { + debug!(skew = skew_secs, "Time sync OK"); + } + + Some(result) +} + +/// Background time sync task +pub async fn time_sync_task(check_interval: Duration) -> ! { + loop { + if let Some(result) = check_time_sync().await { + if result.is_skewed { + error!( + "System clock is off by {} seconds. Please sync your clock.", + result.skew_secs + ); + } + } + + tokio::time::sleep(check_interval).await; + } +} \ No newline at end of file