1.0.0
Tschuss Status Quo - Hallo, Zukunft!
This commit is contained in:
60
Cargo.toml
Normal file
60
Cargo.toml
Normal file
@@ -0,0 +1,60 @@
|
||||
[package]
|
||||
name = "telemt"
|
||||
version = "1.0.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.75"
|
||||
|
||||
[dependencies]
|
||||
# C
|
||||
libc = "0.2"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.35", features = ["full", "tracing"] }
|
||||
tokio-util = { version = "0.7", features = ["codec"] }
|
||||
|
||||
# Crypto
|
||||
aes = "0.8"
|
||||
ctr = "0.9"
|
||||
cbc = "0.1"
|
||||
sha2 = "0.10"
|
||||
sha1 = "0.10"
|
||||
md-5 = "0.10"
|
||||
hmac = "0.12"
|
||||
crc32fast = "1.3"
|
||||
|
||||
# Network
|
||||
socket2 = { version = "0.5", features = ["all"] }
|
||||
rustls = "0.22"
|
||||
|
||||
# Serial
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
toml = "0.8"
|
||||
|
||||
# Utils
|
||||
bytes = "1.5"
|
||||
thiserror = "1.0"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
parking_lot = "0.12"
|
||||
dashmap = "5.5"
|
||||
lru = "0.12"
|
||||
rand = "0.8"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
hex = "0.4"
|
||||
base64 = "0.21"
|
||||
url = "2.5"
|
||||
regex = "1.10"
|
||||
once_cell = "1.19"
|
||||
|
||||
# HTTP
|
||||
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
criterion = "0.5"
|
||||
proptest = "1.4"
|
||||
|
||||
[[bench]]
|
||||
name = "crypto_bench"
|
||||
harness = false
|
||||
12
benches/crypto_bench.rs
Normal file
12
benches/crypto_bench.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
// Cryptobench
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
|
||||
fn bench_aes_ctr(c: &mut Criterion) {
|
||||
c.bench_function("aes_ctr_encrypt_64kb", |b| {
|
||||
let data = vec![0u8; 65536];
|
||||
b.iter(|| {
|
||||
let mut enc = AesCtr::new(&[0u8; 32], 0);
|
||||
black_box(enc.encrypt(&data))
|
||||
})
|
||||
});
|
||||
}
|
||||
13
config.toml
Normal file
13
config.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
port = 443
|
||||
|
||||
[users]
|
||||
user1 = "00000000000000000000000000000000"
|
||||
|
||||
[modes]
|
||||
classic = true
|
||||
secure = true
|
||||
tls = true
|
||||
|
||||
tls_domain = "www.github.com"
|
||||
fast_mode = true
|
||||
prefer_ipv6 = false
|
||||
227
src/config/mod.rs
Normal file
227
src/config/mod.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
//! Configuration
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyModes {
|
||||
#[serde(default)]
|
||||
pub classic: bool,
|
||||
#[serde(default)]
|
||||
pub secure: bool,
|
||||
#[serde(default = "default_true")]
|
||||
pub tls: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool { true }
|
||||
|
||||
impl Default for ProxyModes {
|
||||
fn default() -> Self {
|
||||
Self { classic: true, secure: true, tls: true }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyConfig {
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
#[serde(default)]
|
||||
pub users: HashMap<String, String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub ad_tag: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub modes: ProxyModes,
|
||||
|
||||
#[serde(default = "default_tls_domain")]
|
||||
pub tls_domain: String,
|
||||
|
||||
#[serde(default = "default_true")]
|
||||
pub mask: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub mask_host: Option<String>,
|
||||
|
||||
#[serde(default = "default_mask_port")]
|
||||
pub mask_port: u16,
|
||||
|
||||
#[serde(default)]
|
||||
pub prefer_ipv6: bool,
|
||||
|
||||
#[serde(default = "default_true")]
|
||||
pub fast_mode: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub use_middle_proxy: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub user_max_tcp_conns: HashMap<String, usize>,
|
||||
|
||||
#[serde(default)]
|
||||
pub user_expirations: HashMap<String, DateTime<Utc>>,
|
||||
|
||||
#[serde(default)]
|
||||
pub user_data_quota: HashMap<String, u64>,
|
||||
|
||||
#[serde(default = "default_replay_check_len")]
|
||||
pub replay_check_len: usize,
|
||||
|
||||
#[serde(default)]
|
||||
pub ignore_time_skew: bool,
|
||||
|
||||
#[serde(default = "default_handshake_timeout")]
|
||||
pub client_handshake_timeout: u64,
|
||||
|
||||
#[serde(default = "default_connect_timeout")]
|
||||
pub tg_connect_timeout: u64,
|
||||
|
||||
#[serde(default = "default_keepalive")]
|
||||
pub client_keepalive: u64,
|
||||
|
||||
#[serde(default = "default_ack_timeout")]
|
||||
pub client_ack_timeout: u64,
|
||||
|
||||
#[serde(default = "default_listen_addr")]
|
||||
pub listen_addr_ipv4: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub listen_addr_ipv6: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub listen_unix_sock: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub metrics_port: Option<u16>,
|
||||
|
||||
#[serde(default = "default_metrics_whitelist")]
|
||||
pub metrics_whitelist: Vec<IpAddr>,
|
||||
|
||||
#[serde(default = "default_fake_cert_len")]
|
||||
pub fake_cert_len: usize,
|
||||
}
|
||||
|
||||
fn default_port() -> u16 { 443 }
|
||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||
fn default_mask_port() -> u16 { 443 }
|
||||
fn default_replay_check_len() -> usize { 65536 }
|
||||
fn default_handshake_timeout() -> u64 { 10 }
|
||||
fn default_connect_timeout() -> u64 { 10 }
|
||||
fn default_keepalive() -> u64 { 600 }
|
||||
fn default_ack_timeout() -> u64 { 300 }
|
||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
||||
fn default_fake_cert_len() -> usize { 2048 }
|
||||
|
||||
fn default_metrics_whitelist() -> Vec<IpAddr> {
|
||||
vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"::1".parse().unwrap(),
|
||||
]
|
||||
}
|
||||
|
||||
impl Default for ProxyConfig {
|
||||
fn default() -> Self {
|
||||
let mut users = HashMap::new();
|
||||
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
|
||||
|
||||
Self {
|
||||
port: default_port(),
|
||||
users,
|
||||
ad_tag: None,
|
||||
modes: ProxyModes::default(),
|
||||
tls_domain: default_tls_domain(),
|
||||
mask: true,
|
||||
mask_host: None,
|
||||
mask_port: default_mask_port(),
|
||||
prefer_ipv6: false,
|
||||
fast_mode: true,
|
||||
use_middle_proxy: false,
|
||||
user_max_tcp_conns: HashMap::new(),
|
||||
user_expirations: HashMap::new(),
|
||||
user_data_quota: HashMap::new(),
|
||||
replay_check_len: default_replay_check_len(),
|
||||
ignore_time_skew: false,
|
||||
client_handshake_timeout: default_handshake_timeout(),
|
||||
tg_connect_timeout: default_connect_timeout(),
|
||||
client_keepalive: default_keepalive(),
|
||||
client_ack_timeout: default_ack_timeout(),
|
||||
listen_addr_ipv4: default_listen_addr(),
|
||||
listen_addr_ipv6: Some("::".to_string()),
|
||||
listen_unix_sock: None,
|
||||
metrics_port: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
fake_cert_len: default_fake_cert_len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProxyConfig {
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
|
||||
let mut config: ProxyConfig = toml::from_str(&content)
|
||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
|
||||
// Validate secrets
|
||||
for (user, secret) in &config.users {
|
||||
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
|
||||
return Err(ProxyError::InvalidSecret {
|
||||
user: user.clone(),
|
||||
reason: "Must be 32 hex characters".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Default mask_host
|
||||
if config.mask_host.is_none() {
|
||||
config.mask_host = Some(config.tls_domain.clone());
|
||||
}
|
||||
|
||||
// Random fake_cert_len
|
||||
use rand::Rng;
|
||||
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.users.is_empty() {
|
||||
return Err(ProxyError::Config("No users configured".to_string()));
|
||||
}
|
||||
|
||||
if !self.modes.classic && !self.modes.secure && !self.modes.tls {
|
||||
return Err(ProxyError::Config("No modes enabled".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ProxyConfig::default();
|
||||
assert_eq!(config.port, 443);
|
||||
assert!(config.modes.tls);
|
||||
assert_eq!(config.client_keepalive, 600);
|
||||
assert_eq!(config.client_ack_timeout, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate() {
|
||||
let mut config = ProxyConfig::default();
|
||||
assert!(config.validate().is_ok());
|
||||
|
||||
config.users.clear();
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
}
|
||||
351
src/crypto/aes.rs
Normal file
351
src/crypto/aes.rs
Normal file
@@ -0,0 +1,351 @@
|
||||
//! AES
|
||||
|
||||
use aes::Aes256;
|
||||
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
||||
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
|
||||
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
type Aes256Ctr = Ctr128BE<Aes256>;
|
||||
type Aes256CbcEnc = CbcEncryptor<Aes256>;
|
||||
type Aes256CbcDec = CbcDecryptor<Aes256>;
|
||||
|
||||
/// AES-256-CTR encryptor/decryptor
|
||||
pub struct AesCtr {
|
||||
cipher: Aes256Ctr,
|
||||
}
|
||||
|
||||
impl AesCtr {
|
||||
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
||||
let iv_bytes = iv.to_be_bytes();
|
||||
Self {
|
||||
cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||
if key.len() != 32 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||
}
|
||||
if iv.len() != 16 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
|
||||
}
|
||||
|
||||
let key: [u8; 32] = key.try_into().unwrap();
|
||||
let iv = u128::from_be_bytes(iv.try_into().unwrap());
|
||||
Ok(Self::new(&key, iv))
|
||||
}
|
||||
|
||||
/// Encrypt/decrypt data in-place (CTR mode is symmetric)
|
||||
pub fn apply(&mut self, data: &mut [u8]) {
|
||||
self.cipher.apply_keystream(data);
|
||||
}
|
||||
|
||||
/// Encrypt data, returning new buffer
|
||||
pub fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
let mut output = data.to_vec();
|
||||
self.apply(&mut output);
|
||||
output
|
||||
}
|
||||
|
||||
/// Decrypt data (for CTR, identical to encrypt)
|
||||
pub fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
self.encrypt(data)
|
||||
}
|
||||
}
|
||||
|
||||
/// AES-256-CBC Ciphermagic
|
||||
pub struct AesCbc {
|
||||
key: [u8; 32],
|
||||
iv: [u8; 16],
|
||||
}
|
||||
|
||||
impl AesCbc {
|
||||
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
||||
Self { key, iv }
|
||||
}
|
||||
|
||||
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||
if key.len() != 32 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||
}
|
||||
if iv.len() != 16 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
key: key.try_into().unwrap(),
|
||||
iv: iv.try_into().unwrap(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Encrypt data using CBC mode
|
||||
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % 16 != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut buffer = data.to_vec();
|
||||
|
||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
||||
|
||||
for chunk in buffer.chunks_mut(16) {
|
||||
encryptor.encrypt_block_mut(chunk.into());
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Decrypt data using CBC mode
|
||||
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % 16 != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut buffer = data.to_vec();
|
||||
|
||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
||||
|
||||
for chunk in buffer.chunks_mut(16) {
|
||||
decryptor.decrypt_block_mut(chunk.into());
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Encrypt data in-place
|
||||
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||
if data.len() % 16 != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
||||
|
||||
for chunk in data.chunks_mut(16) {
|
||||
encryptor.encrypt_block_mut(chunk.into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Decrypt data in-place
|
||||
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||
if data.len() % 16 != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
||||
|
||||
for chunk in data.chunks_mut(16) {
|
||||
decryptor.decrypt_block_mut(chunk.into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for unified encryption interface
|
||||
pub trait Encryptor: Send + Sync {
|
||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Trait for unified decryption interface
|
||||
pub trait Decryptor: Send + Sync {
|
||||
fn decrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
||||
}
|
||||
|
||||
impl Encryptor for AesCtr {
|
||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
AesCtr::encrypt(self, data)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decryptor for AesCtr {
|
||||
fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
AesCtr::decrypt(self, data)
|
||||
}
|
||||
}
|
||||
|
||||
/// No-op encryptor for fast mode
|
||||
pub struct PassthroughEncryptor;
|
||||
|
||||
impl Encryptor for PassthroughEncryptor {
|
||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
data.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decryptor for PassthroughEncryptor {
|
||||
fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
|
||||
data.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_aes_ctr_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
let iv = 12345u128;
|
||||
|
||||
let original = b"Hello, MTProto!";
|
||||
|
||||
let mut enc = AesCtr::new(&key, iv);
|
||||
let encrypted = enc.encrypt(original);
|
||||
|
||||
let mut dec = AesCtr::new(&key, iv);
|
||||
let decrypted = dec.decrypt(&encrypted);
|
||||
|
||||
assert_eq!(original.as_slice(), decrypted.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
let iv = [0u8; 16];
|
||||
|
||||
// Must be aligned to 16 bytes
|
||||
let original = [0u8; 32];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let encrypted = cipher.encrypt(&original).unwrap();
|
||||
let decrypted = cipher.decrypt(&encrypted).unwrap();
|
||||
|
||||
assert_eq!(original.as_slice(), decrypted.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_chaining_works() {
|
||||
let key = [0x42u8; 32];
|
||||
let iv = [0x00u8; 16];
|
||||
|
||||
let plaintext = [0xAA_u8; 32];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||
|
||||
// CBC Corrections
|
||||
let block1 = &ciphertext[0..16];
|
||||
let block2 = &ciphertext[16..32];
|
||||
|
||||
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_known_vector() {
|
||||
let key = [0u8; 32];
|
||||
let iv = [0u8; 16];
|
||||
|
||||
// 3 Datablocks
|
||||
let plaintext = [
|
||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
||||
// Block 2
|
||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
||||
// Block 3 - different
|
||||
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
|
||||
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
|
||||
];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||
|
||||
// Decrypt + Verify
|
||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
||||
|
||||
// Verify Ciphertexts Block 1 != Block 2
|
||||
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_in_place() {
|
||||
let key = [0x12u8; 32];
|
||||
let iv = [0x34u8; 16];
|
||||
|
||||
let original = [0x56u8; 48]; // 3 blocks
|
||||
let mut buffer = original.clone();
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
|
||||
cipher.encrypt_in_place(&mut buffer).unwrap();
|
||||
assert_ne!(&buffer[..], &original[..]);
|
||||
|
||||
cipher.decrypt_in_place(&mut buffer).unwrap();
|
||||
assert_eq!(&buffer[..], &original[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_empty_data() {
|
||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||
|
||||
let encrypted = cipher.encrypt(&[]).unwrap();
|
||||
assert!(encrypted.is_empty());
|
||||
|
||||
let decrypted = cipher.decrypt(&[]).unwrap();
|
||||
assert!(decrypted.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_unaligned_error() {
|
||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||
|
||||
// 15 bytes
|
||||
let result = cipher.encrypt(&[0u8; 15]);
|
||||
assert!(result.is_err());
|
||||
|
||||
// 17 bytes
|
||||
let result = cipher.encrypt(&[0u8; 17]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_avalanche_effect() {
|
||||
// Cipherplane
|
||||
|
||||
let key = [0xAB; 32];
|
||||
let iv = [0xCD; 16];
|
||||
|
||||
let mut plaintext1 = [0u8; 32];
|
||||
let mut plaintext2 = [0u8; 32];
|
||||
plaintext2[0] = 0x01; // Один бит отличается
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
|
||||
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
||||
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
||||
|
||||
// First Blocks Diff
|
||||
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
||||
|
||||
// Second Blocks Diff
|
||||
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
||||
}
|
||||
}
|
||||
90
src/crypto/hash.rs
Normal file
90
src/crypto/hash.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use md5::Md5;
|
||||
use sha1::Sha1;
|
||||
use sha2::Digest;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// SHA-256
|
||||
pub fn sha256(data: &[u8]) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// SHA-256 HMAC
|
||||
pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
|
||||
let mut mac = HmacSha256::new_from_slice(key)
|
||||
.expect("HMAC accepts any key length");
|
||||
mac.update(data);
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
/// SHA-1
|
||||
pub fn sha1(data: &[u8]) -> [u8; 20] {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// MD5
|
||||
pub fn md5(data: &[u8]) -> [u8; 16] {
|
||||
let mut hasher = Md5::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// CRC32
|
||||
pub fn crc32(data: &[u8]) -> u32 {
|
||||
crc32fast::hash(data)
|
||||
}
|
||||
|
||||
/// Middle Proxy Keygen
|
||||
pub fn derive_middleproxy_keys(
|
||||
nonce_srv: &[u8; 16],
|
||||
nonce_clt: &[u8; 16],
|
||||
clt_ts: &[u8; 4],
|
||||
srv_ip: Option<&[u8]>,
|
||||
clt_port: &[u8; 2],
|
||||
purpose: &[u8],
|
||||
clt_ip: Option<&[u8]>,
|
||||
srv_port: &[u8; 2],
|
||||
secret: &[u8],
|
||||
clt_ipv6: Option<&[u8; 16]>,
|
||||
srv_ipv6: Option<&[u8; 16]>,
|
||||
) -> ([u8; 32], [u8; 16]) {
|
||||
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
|
||||
|
||||
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
|
||||
let clt_ip = clt_ip.unwrap_or(&EMPTY_IP);
|
||||
|
||||
let mut s = Vec::with_capacity(256);
|
||||
s.extend_from_slice(nonce_srv);
|
||||
s.extend_from_slice(nonce_clt);
|
||||
s.extend_from_slice(clt_ts);
|
||||
s.extend_from_slice(srv_ip);
|
||||
s.extend_from_slice(clt_port);
|
||||
s.extend_from_slice(purpose);
|
||||
s.extend_from_slice(clt_ip);
|
||||
s.extend_from_slice(srv_port);
|
||||
s.extend_from_slice(secret);
|
||||
s.extend_from_slice(nonce_srv);
|
||||
|
||||
if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) {
|
||||
s.extend_from_slice(clt_v6);
|
||||
s.extend_from_slice(srv_v6);
|
||||
}
|
||||
|
||||
s.extend_from_slice(nonce_clt);
|
||||
|
||||
let md5_1 = md5(&s[1..]);
|
||||
let sha1_sum = sha1(&s);
|
||||
let md5_2 = md5(&s[2..]);
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
key[..12].copy_from_slice(&md5_1[..12]);
|
||||
key[12..].copy_from_slice(&sha1_sum);
|
||||
|
||||
(key, md5_2)
|
||||
}
|
||||
9
src/crypto/mod.rs
Normal file
9
src/crypto/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Crypto
|
||||
|
||||
pub mod aes;
|
||||
pub mod hash;
|
||||
pub mod random;
|
||||
|
||||
pub use aes::{AesCtr, AesCbc};
|
||||
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
|
||||
pub use random::{SecureRandom, SECURE_RANDOM};
|
||||
212
src/crypto/random.rs
Normal file
212
src/crypto/random.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Pseudorandom
|
||||
|
||||
use rand::{Rng, RngCore, SeedableRng};
|
||||
use rand::rngs::StdRng;
|
||||
use parking_lot::Mutex;
|
||||
use crate::crypto::AesCtr;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// Global secure random instance
|
||||
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
|
||||
|
||||
/// Cryptographically secure PRNG with AES-CTR
|
||||
pub struct SecureRandom {
|
||||
inner: Mutex<SecureRandomInner>,
|
||||
}
|
||||
|
||||
struct SecureRandomInner {
|
||||
rng: StdRng,
|
||||
cipher: AesCtr,
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SecureRandom {
|
||||
pub fn new() -> Self {
|
||||
let mut rng = StdRng::from_entropy();
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
rng.fill_bytes(&mut key);
|
||||
let iv: u128 = rng.gen();
|
||||
|
||||
Self {
|
||||
inner: Mutex::new(SecureRandomInner {
|
||||
rng,
|
||||
cipher: AesCtr::new(&key, iv),
|
||||
buffer: Vec::with_capacity(1024),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random bytes
|
||||
pub fn bytes(&self, len: usize) -> Vec<u8> {
|
||||
let mut inner = self.inner.lock();
|
||||
const CHUNK_SIZE: usize = 512;
|
||||
|
||||
while inner.buffer.len() < len {
|
||||
let mut chunk = vec![0u8; CHUNK_SIZE];
|
||||
inner.rng.fill_bytes(&mut chunk);
|
||||
inner.cipher.apply(&mut chunk);
|
||||
inner.buffer.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
inner.buffer.drain(..len).collect()
|
||||
}
|
||||
|
||||
/// Generate random number in range [0, max)
|
||||
pub fn range(&self, max: usize) -> usize {
|
||||
if max == 0 {
|
||||
return 0;
|
||||
}
|
||||
let mut inner = self.inner.lock();
|
||||
inner.rng.gen_range(0..max)
|
||||
}
|
||||
|
||||
/// Generate random bits
|
||||
pub fn bits(&self, k: usize) -> u64 {
|
||||
if k == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let bytes_needed = (k + 7) / 8;
|
||||
let bytes = self.bytes(bytes_needed.min(8));
|
||||
|
||||
let mut result = 0u64;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
if i >= 8 {
|
||||
break;
|
||||
}
|
||||
result |= (b as u64) << (i * 8);
|
||||
}
|
||||
|
||||
// Mask extra bits
|
||||
if k < 64 {
|
||||
result &= (1u64 << k) - 1;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Choose random element from slice
|
||||
pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> {
|
||||
if slice.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(&slice[self.range(slice.len())])
|
||||
}
|
||||
}
|
||||
|
||||
/// Shuffle slice in place
|
||||
pub fn shuffle<T>(&self, slice: &mut [T]) {
|
||||
let mut inner = self.inner.lock();
|
||||
for i in (1..slice.len()).rev() {
|
||||
let j = inner.rng.gen_range(0..=i);
|
||||
slice.swap(i, j);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random u32
|
||||
pub fn u32(&self) -> u32 {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.rng.gen()
|
||||
}
|
||||
|
||||
/// Generate random u64
|
||||
pub fn u64(&self) -> u64 {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.rng.gen()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecureRandom {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn test_bytes_uniqueness() {
|
||||
let rng = SecureRandom::new();
|
||||
let a = rng.bytes(32);
|
||||
let b = rng.bytes(32);
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bytes_length() {
|
||||
let rng = SecureRandom::new();
|
||||
assert_eq!(rng.bytes(0).len(), 0);
|
||||
assert_eq!(rng.bytes(1).len(), 1);
|
||||
assert_eq!(rng.bytes(100).len(), 100);
|
||||
assert_eq!(rng.bytes(1000).len(), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range() {
|
||||
let rng = SecureRandom::new();
|
||||
|
||||
for _ in 0..1000 {
|
||||
let n = rng.range(10);
|
||||
assert!(n < 10);
|
||||
}
|
||||
|
||||
assert_eq!(rng.range(1), 0);
|
||||
assert_eq!(rng.range(0), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bits() {
|
||||
let rng = SecureRandom::new();
|
||||
|
||||
// Single bit should be 0 or 1
|
||||
for _ in 0..100 {
|
||||
assert!(rng.bits(1) <= 1);
|
||||
}
|
||||
|
||||
// 8 bits should be 0-255
|
||||
for _ in 0..100 {
|
||||
assert!(rng.bits(8) <= 255);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_choose() {
|
||||
let rng = SecureRandom::new();
|
||||
let items = vec![1, 2, 3, 4, 5];
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
for _ in 0..1000 {
|
||||
if let Some(&item) = rng.choose(&items) {
|
||||
seen.insert(item);
|
||||
}
|
||||
}
|
||||
|
||||
// Should have seen all items
|
||||
assert_eq!(seen.len(), 5);
|
||||
|
||||
// Empty slice should return None
|
||||
let empty: Vec<i32> = vec![];
|
||||
assert!(rng.choose(&empty).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle() {
|
||||
let rng = SecureRandom::new();
|
||||
let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
|
||||
let mut shuffled = original.clone();
|
||||
rng.shuffle(&mut shuffled);
|
||||
|
||||
// Should contain same elements
|
||||
let mut sorted = shuffled.clone();
|
||||
sorted.sort();
|
||||
assert_eq!(sorted, original);
|
||||
|
||||
// Should be different order (with very high probability)
|
||||
assert_ne!(shuffled, original);
|
||||
}
|
||||
}
|
||||
176
src/error.rs
Normal file
176
src/error.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Error Types
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ProxyError {
|
||||
// ============= Crypto Errors =============
|
||||
|
||||
#[error("Crypto error: {0}")]
|
||||
Crypto(String),
|
||||
|
||||
#[error("Invalid key length: expected {expected}, got {got}")]
|
||||
InvalidKeyLength { expected: usize, got: usize },
|
||||
|
||||
// ============= Protocol Errors =============
|
||||
|
||||
#[error("Invalid handshake: {0}")]
|
||||
InvalidHandshake(String),
|
||||
|
||||
#[error("Invalid protocol tag: {0:02x?}")]
|
||||
InvalidProtoTag([u8; 4]),
|
||||
|
||||
#[error("Invalid TLS record: type={record_type}, version={version:02x?}")]
|
||||
InvalidTlsRecord { record_type: u8, version: [u8; 2] },
|
||||
|
||||
#[error("Replay attack detected from {addr}")]
|
||||
ReplayAttack { addr: SocketAddr },
|
||||
|
||||
#[error("Time skew detected: client={client_time}, server={server_time}")]
|
||||
TimeSkew { client_time: u32, server_time: u32 },
|
||||
|
||||
#[error("Invalid message length: {len} (min={min}, max={max})")]
|
||||
InvalidMessageLength { len: usize, min: usize, max: usize },
|
||||
|
||||
#[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")]
|
||||
ChecksumMismatch { expected: u32, got: u32 },
|
||||
|
||||
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
||||
SeqNoMismatch { expected: i32, got: i32 },
|
||||
|
||||
// ============= Network Errors =============
|
||||
|
||||
#[error("Connection timeout to {addr}")]
|
||||
ConnectionTimeout { addr: String },
|
||||
|
||||
#[error("Connection refused by {addr}")]
|
||||
ConnectionRefused { addr: String },
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
// ============= Proxy Protocol Errors =============
|
||||
|
||||
#[error("Invalid proxy protocol header")]
|
||||
InvalidProxyProtocol,
|
||||
|
||||
// ============= Config Errors =============
|
||||
|
||||
#[error("Config error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Invalid secret for user {user}: {reason}")]
|
||||
InvalidSecret { user: String, reason: String },
|
||||
|
||||
// ============= User Errors =============
|
||||
|
||||
#[error("User {user} expired")]
|
||||
UserExpired { user: String },
|
||||
|
||||
#[error("User {user} exceeded connection limit")]
|
||||
ConnectionLimitExceeded { user: String },
|
||||
|
||||
#[error("User {user} exceeded data quota")]
|
||||
DataQuotaExceeded { user: String },
|
||||
|
||||
#[error("Unknown user")]
|
||||
UnknownUser,
|
||||
|
||||
// ============= General Errors =============
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
/// Convenient Result type alias
|
||||
pub type Result<T> = std::result::Result<T, ProxyError>;
|
||||
|
||||
/// Result with optional bad client handling
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeResult<T> {
|
||||
/// Handshake succeeded
|
||||
Success(T),
|
||||
/// Client failed validation, needs masking
|
||||
BadClient,
|
||||
/// Error occurred
|
||||
Error(ProxyError),
|
||||
}
|
||||
|
||||
impl<T> HandshakeResult<T> {
|
||||
/// Check if successful
|
||||
pub fn is_success(&self) -> bool {
|
||||
matches!(self, HandshakeResult::Success(_))
|
||||
}
|
||||
|
||||
/// Check if bad client
|
||||
pub fn is_bad_client(&self) -> bool {
|
||||
matches!(self, HandshakeResult::BadClient)
|
||||
}
|
||||
|
||||
/// Convert to Result, treating BadClient as error
|
||||
pub fn into_result(self) -> Result<T> {
|
||||
match self {
|
||||
HandshakeResult::Success(v) => Ok(v),
|
||||
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
|
||||
HandshakeResult::Error(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map the success value
|
||||
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> {
|
||||
match self {
|
||||
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
|
||||
HandshakeResult::BadClient => HandshakeResult::BadClient,
|
||||
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<ProxyError> for HandshakeResult<T> {
|
||||
fn from(err: ProxyError) -> Self {
|
||||
HandshakeResult::Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<std::io::Error> for HandshakeResult<T> {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
HandshakeResult::Error(ProxyError::Io(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_handshake_result() {
|
||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
||||
assert!(success.is_success());
|
||||
assert!(!success.is_bad_client());
|
||||
|
||||
let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
|
||||
assert!(!bad.is_success());
|
||||
assert!(bad.is_bad_client());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handshake_result_map() {
|
||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
||||
let mapped = success.map(|x| x * 2);
|
||||
|
||||
match mapped {
|
||||
HandshakeResult::Success(v) => assert_eq!(v, 84),
|
||||
_ => panic!("Expected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
||||
assert!(err.to_string().contains("1.2.3.4:443"));
|
||||
|
||||
let err = ProxyError::InvalidProxyProtocol;
|
||||
assert!(err.to_string().contains("proxy protocol"));
|
||||
}
|
||||
}
|
||||
158
src/main.rs
Normal file
158
src/main.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
//! Telemt - MTProxy on Rust
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::{info, error, Level};
|
||||
use tracing_subscriber::{FmtSubscriber, EnvFilter};
|
||||
|
||||
mod error;
|
||||
mod crypto;
|
||||
mod protocol;
|
||||
mod stream;
|
||||
mod transport;
|
||||
mod proxy;
|
||||
mod config;
|
||||
mod stats;
|
||||
mod util;
|
||||
|
||||
use config::ProxyConfig;
|
||||
use stats::{Stats, ReplayChecker};
|
||||
use transport::ConnectionPool;
|
||||
use proxy::ClientHandler;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging with env filter
|
||||
// Use RUST_LOG=debug or RUST_LOG=trace for more details
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(false)
|
||||
.with_file(false)
|
||||
.with_line_number(false)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber)?;
|
||||
|
||||
// Load configuration
|
||||
let config_path = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "config.toml".to_string());
|
||||
|
||||
info!("Loading configuration from {}", config_path);
|
||||
|
||||
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
|
||||
error!("Failed to load config: {}", e);
|
||||
info!("Using default configuration");
|
||||
ProxyConfig::default()
|
||||
});
|
||||
|
||||
if let Err(e) = config.validate() {
|
||||
error!("Invalid configuration: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
info!("Starting MTProto Proxy on port {}", config.port);
|
||||
info!("Fast mode: {}", config.fast_mode);
|
||||
info!("Modes: classic={}, secure={}, tls={}",
|
||||
config.modes.classic, config.modes.secure, config.modes.tls);
|
||||
|
||||
// Initialize components
|
||||
let stats = Arc::new(Stats::new());
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||
let pool = Arc::new(ConnectionPool::new());
|
||||
|
||||
// Create handler
|
||||
let handler = Arc::new(ClientHandler::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&stats),
|
||||
Arc::clone(&replay_checker),
|
||||
Arc::clone(&pool),
|
||||
));
|
||||
|
||||
// Start listener
|
||||
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
|
||||
.parse()?;
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Print proxy links
|
||||
print_proxy_links(&config);
|
||||
|
||||
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
|
||||
|
||||
// Main accept loop
|
||||
let accept_loop = async {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer)) => {
|
||||
let handler = Arc::clone(&handler);
|
||||
tokio::spawn(async move {
|
||||
handler.handle(stream, peer).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Graceful shutdown
|
||||
tokio::select! {
|
||||
_ = accept_loop => {}
|
||||
_ = signal::ctrl_c() => {
|
||||
info!("Shutting down...");
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
pool.close_all().await;
|
||||
|
||||
info!("Goodbye!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_proxy_links(config: &ProxyConfig) {
|
||||
println!("\n=== Proxy Links ===\n");
|
||||
|
||||
for (user, secret) in &config.users {
|
||||
if config.modes.tls {
|
||||
let tls_secret = format!(
|
||||
"ee{}{}",
|
||||
secret,
|
||||
hex::encode(config.tls_domain.as_bytes())
|
||||
);
|
||||
println!(
|
||||
"{} (TLS): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, tls_secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.secure {
|
||||
println!(
|
||||
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.classic {
|
||||
println!(
|
||||
"{} (Classic): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("===================\n");
|
||||
}
|
||||
261
src/protocol/constants.rs
Normal file
261
src/protocol/constants.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
//! Protocol constants and datacenter addresses
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
// ============= Telegram Datacenters =============
|
||||
|
||||
pub const TG_DATACENTER_PORT: u16 = 443;
|
||||
|
||||
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
||||
vec![
|
||||
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
|
||||
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
|
||||
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)),
|
||||
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 91)),
|
||||
IpAddr::V4(Ipv4Addr::new(149, 154, 171, 5)),
|
||||
]
|
||||
});
|
||||
|
||||
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
||||
vec![
|
||||
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
|
||||
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
|
||||
IpAddr::V6("2001:b28:f23d:f003::a".parse().unwrap()),
|
||||
IpAddr::V6("2001:67c:04e8:f004::a".parse().unwrap()),
|
||||
IpAddr::V6("2001:b28:f23f:f005::a".parse().unwrap()),
|
||||
]
|
||||
});
|
||||
|
||||
// ============= Middle Proxies (for advertising) =============
|
||||
|
||||
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||
Lazy::new(|| {
|
||||
let mut m = std::collections::HashMap::new();
|
||||
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||
m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
|
||||
m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
|
||||
m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
|
||||
m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
|
||||
m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]);
|
||||
m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]);
|
||||
m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
|
||||
m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
|
||||
m
|
||||
});
|
||||
|
||||
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||
Lazy::new(|| {
|
||||
let mut m = std::collections::HashMap::new();
|
||||
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||
m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
|
||||
m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
|
||||
m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
|
||||
m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
|
||||
m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
|
||||
m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
|
||||
m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
|
||||
m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
|
||||
m
|
||||
});
|
||||
|
||||
// ============= Protocol Tags =============
|
||||
|
||||
/// MTProto transport protocol variants
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(u32)]
|
||||
pub enum ProtoTag {
|
||||
/// Abridged protocol - compact framing
|
||||
Abridged = 0xefefefef,
|
||||
/// Intermediate protocol - simple 4-byte length prefix
|
||||
Intermediate = 0xeeeeeeee,
|
||||
/// Secure intermediate - with random padding
|
||||
Secure = 0xdddddddd,
|
||||
}
|
||||
|
||||
impl ProtoTag {
|
||||
/// Parse protocol tag from 4 bytes
|
||||
pub fn from_bytes(bytes: [u8; 4]) -> Option<Self> {
|
||||
match u32::from_le_bytes(bytes) {
|
||||
0xefefefef => Some(ProtoTag::Abridged),
|
||||
0xeeeeeeee => Some(ProtoTag::Intermediate),
|
||||
0xdddddddd => Some(ProtoTag::Secure),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to 4 bytes (little-endian)
|
||||
pub fn to_bytes(self) -> [u8; 4] {
|
||||
(self as u32).to_le_bytes()
|
||||
}
|
||||
|
||||
/// Get protocol tag as bytes slice
|
||||
pub fn as_bytes(&self) -> &'static [u8; 4] {
|
||||
match self {
|
||||
ProtoTag::Abridged => &PROTO_TAG_ABRIDGED,
|
||||
ProtoTag::Intermediate => &PROTO_TAG_INTERMEDIATE,
|
||||
ProtoTag::Secure => &PROTO_TAG_SECURE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Protocol tag bytes
|
||||
pub const PROTO_TAG_ABRIDGED: [u8; 4] = [0xef, 0xef, 0xef, 0xef];
|
||||
pub const PROTO_TAG_INTERMEDIATE: [u8; 4] = [0xee, 0xee, 0xee, 0xee];
|
||||
pub const PROTO_TAG_SECURE: [u8; 4] = [0xdd, 0xdd, 0xdd, 0xdd];
|
||||
|
||||
// ============= Handshake Layout =============
|
||||
|
||||
/// Bytes to skip at the start of handshake
|
||||
pub const SKIP_LEN: usize = 8;
|
||||
/// Pre-key length (before hashing with secret)
|
||||
pub const PREKEY_LEN: usize = 32;
|
||||
/// AES key length
|
||||
pub const KEY_LEN: usize = 32;
|
||||
/// AES IV length
|
||||
pub const IV_LEN: usize = 16;
|
||||
/// Total handshake length
|
||||
pub const HANDSHAKE_LEN: usize = 64;
|
||||
/// Position of protocol tag in decrypted handshake
|
||||
pub const PROTO_TAG_POS: usize = 56;
|
||||
/// Position of datacenter index
|
||||
pub const DC_IDX_POS: usize = 60;
|
||||
|
||||
// ============= Message Limits =============
|
||||
|
||||
/// Minimum message length
|
||||
pub const MIN_MSG_LEN: usize = 12;
|
||||
/// Maximum message length (16 MB)
|
||||
pub const MAX_MSG_LEN: usize = 1 << 24;
|
||||
/// CBC block padding size
|
||||
pub const CBC_PADDING: usize = 16;
|
||||
/// Padding filler bytes
|
||||
pub const PADDING_FILLER: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
|
||||
|
||||
// ============= TLS Constants =============
|
||||
|
||||
/// Minimum certificate length for detection
|
||||
pub const MIN_CERT_LEN: usize = 1024;
|
||||
/// TLS 1.3 version bytes
|
||||
pub const TLS_VERSION: [u8; 2] = [0x03, 0x03];
|
||||
/// TLS record type: Handshake
|
||||
pub const TLS_RECORD_HANDSHAKE: u8 = 0x16;
|
||||
/// TLS record type: Change Cipher Spec
|
||||
pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14;
|
||||
/// TLS record type: Application Data
|
||||
pub const TLS_RECORD_APPLICATION: u8 = 0x17;
|
||||
/// TLS record type: Alert
|
||||
pub const TLS_RECORD_ALERT: u8 = 0x15;
|
||||
/// Maximum TLS record size
|
||||
pub const MAX_TLS_RECORD_SIZE: usize = 16384;
|
||||
/// Maximum TLS chunk size (with overhead)
|
||||
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 24;
|
||||
|
||||
// ============= Timeouts =============
|
||||
|
||||
/// Default handshake timeout in seconds
|
||||
pub const DEFAULT_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
|
||||
/// Default connect timeout in seconds
|
||||
pub const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
|
||||
/// Default keepalive interval in seconds
|
||||
pub const DEFAULT_KEEPALIVE_SECS: u64 = 600;
|
||||
/// Default ACK timeout in seconds
|
||||
pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
// ============= Buffer Sizes =============
|
||||
|
||||
/// Default buffer size
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
|
||||
/// Small buffer size for bad client handling
|
||||
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
// ============= Statistics =============
|
||||
|
||||
/// Duration buckets for histogram metrics
|
||||
pub static DURATION_BUCKETS: &[f64] = &[
|
||||
0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0,
|
||||
];
|
||||
|
||||
// ============= Reserved Nonce Patterns =============
|
||||
|
||||
/// Reserved first bytes of nonce (must avoid)
|
||||
pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef];
|
||||
|
||||
/// Reserved 4-byte beginnings of nonce
|
||||
pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[
|
||||
[0x48, 0x45, 0x41, 0x44], // HEAD
|
||||
[0x50, 0x4F, 0x53, 0x54], // POST
|
||||
[0x47, 0x45, 0x54, 0x20], // GET
|
||||
[0xee, 0xee, 0xee, 0xee], // Intermediate
|
||||
[0xdd, 0xdd, 0xdd, 0xdd], // Secure
|
||||
[0x16, 0x03, 0x01, 0x02], // TLS
|
||||
];
|
||||
|
||||
/// Reserved continuation bytes (bytes 4-7)
|
||||
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
|
||||
[0x00, 0x00, 0x00, 0x00],
|
||||
];
|
||||
|
||||
// ============= RPC Constants (for Middle Proxy) =============
|
||||
|
||||
/// RPC Proxy Request
|
||||
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
|
||||
/// RPC Proxy Answer
|
||||
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
|
||||
/// RPC Close Extended
|
||||
pub const RPC_CLOSE_EXT: [u8; 4] = [0xa2, 0x34, 0xb6, 0x5e];
|
||||
/// RPC Simple ACK
|
||||
pub const RPC_SIMPLE_ACK: [u8; 4] = [0x9b, 0x40, 0xac, 0x3b];
|
||||
/// RPC Unknown
|
||||
pub const RPC_UNKNOWN: [u8; 4] = [0xdf, 0xa2, 0x30, 0x57];
|
||||
/// RPC Handshake
|
||||
pub const RPC_HANDSHAKE: [u8; 4] = [0xf5, 0xee, 0x82, 0x76];
|
||||
/// RPC Nonce
|
||||
pub const RPC_NONCE: [u8; 4] = [0xaa, 0x87, 0xcb, 0x7a];
|
||||
|
||||
/// RPC Flags
|
||||
pub mod rpc_flags {
|
||||
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
|
||||
pub const FLAG_HAS_AD_TAG: u32 = 0x8;
|
||||
pub const FLAG_MAGIC: u32 = 0x1000;
|
||||
pub const FLAG_EXTMODE2: u32 = 0x20000;
|
||||
pub const FLAG_PAD: u32 = 0x8000000;
|
||||
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
|
||||
pub const FLAG_ABRIDGED: u32 = 0x40000000;
|
||||
pub const FLAG_QUICKACK: u32 = 0x80000000;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_proto_tag_roundtrip() {
|
||||
for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||
let bytes = tag.to_bytes();
|
||||
let parsed = ProtoTag::from_bytes(bytes).unwrap();
|
||||
assert_eq!(tag, parsed);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proto_tag_values() {
|
||||
assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED);
|
||||
assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE);
|
||||
assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_proto_tag() {
|
||||
assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none());
|
||||
assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_datacenters_count() {
|
||||
assert_eq!(TG_DATACENTERS_V4.len(), 5);
|
||||
assert_eq!(TG_DATACENTERS_V6.len(), 5);
|
||||
}
|
||||
}
|
||||
120
src/protocol/frame.rs
Normal file
120
src/protocol/frame.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
//! MTProto frame types and metadata
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Extra metadata associated with a frame
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FrameExtra {
|
||||
/// Quick ACK flag - request immediate acknowledgment
|
||||
pub quickack: bool,
|
||||
/// Simple ACK - this is an acknowledgment message
|
||||
pub simple_ack: bool,
|
||||
/// Skip sending - internal flag to skip forwarding
|
||||
pub skip_send: bool,
|
||||
/// Custom key-value metadata
|
||||
pub custom: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl FrameExtra {
|
||||
/// Create new empty frame extra
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create with quickack flag set
|
||||
pub fn with_quickack() -> Self {
|
||||
Self {
|
||||
quickack: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with simple_ack flag set
|
||||
pub fn with_simple_ack() -> Self {
|
||||
Self {
|
||||
simple_ack: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if any flags are set
|
||||
pub fn has_flags(&self) -> bool {
|
||||
self.quickack || self.simple_ack || self.skip_send
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of reading a frame
|
||||
#[derive(Debug)]
|
||||
pub enum FrameReadResult {
|
||||
/// Successfully read a frame with data and metadata
|
||||
Data(Vec<u8>, FrameExtra),
|
||||
/// Connection closed normally
|
||||
Closed,
|
||||
/// Need more data (for non-blocking reads)
|
||||
WouldBlock,
|
||||
}
|
||||
|
||||
/// Frame encoding/decoding mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FrameMode {
|
||||
/// Abridged - 1 or 4 byte length prefix
|
||||
Abridged,
|
||||
/// Intermediate - 4 byte length prefix
|
||||
Intermediate,
|
||||
/// Secure Intermediate - 4 byte length with padding
|
||||
SecureIntermediate,
|
||||
/// Full MTProto - with seq_no and CRC32
|
||||
Full,
|
||||
}
|
||||
|
||||
impl FrameMode {
|
||||
/// Get maximum overhead for this frame mode
|
||||
pub fn max_overhead(&self) -> usize {
|
||||
match self {
|
||||
FrameMode::Abridged => 4,
|
||||
FrameMode::Intermediate => 4,
|
||||
FrameMode::SecureIntermediate => 4 + 3, // length + padding
|
||||
FrameMode::Full => 12 + 16, // header + max CBC padding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate message length for MTProto
|
||||
pub fn validate_message_length(len: usize) -> bool {
|
||||
use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER};
|
||||
|
||||
len >= MIN_MSG_LEN && len <= MAX_MSG_LEN && len % PADDING_FILLER.len() == 0
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_frame_extra_default() {
|
||||
let extra = FrameExtra::default();
|
||||
assert!(!extra.quickack);
|
||||
assert!(!extra.simple_ack);
|
||||
assert!(!extra.skip_send);
|
||||
assert!(!extra.has_flags());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_extra_flags() {
|
||||
let extra = FrameExtra::with_quickack();
|
||||
assert!(extra.quickack);
|
||||
assert!(extra.has_flags());
|
||||
|
||||
let extra = FrameExtra::with_simple_ack();
|
||||
assert!(extra.simple_ack);
|
||||
assert!(extra.has_flags());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_message_length() {
|
||||
assert!(validate_message_length(12)); // MIN_MSG_LEN
|
||||
assert!(validate_message_length(16));
|
||||
assert!(!validate_message_length(8)); // Too small
|
||||
assert!(!validate_message_length(13)); // Not aligned to 4
|
||||
}
|
||||
}
|
||||
11
src/protocol/mod.rs
Normal file
11
src/protocol/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! MTProto Defs + Cons
|
||||
|
||||
pub mod constants;
|
||||
pub mod frame;
|
||||
pub mod obfuscation;
|
||||
pub mod tls;
|
||||
|
||||
pub use constants::*;
|
||||
pub use frame::*;
|
||||
pub use obfuscation::*;
|
||||
pub use tls::*;
|
||||
217
src/protocol/obfuscation.rs
Normal file
217
src/protocol/obfuscation.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
//! MTProto Obfuscation
|
||||
|
||||
use crate::crypto::{sha256, AesCtr};
|
||||
use crate::error::Result;
|
||||
use super::constants::*;
|
||||
|
||||
/// Obfuscation parameters from handshake
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ObfuscationParams {
|
||||
/// Key for decrypting client -> proxy traffic
|
||||
pub decrypt_key: [u8; 32],
|
||||
/// IV for decrypting client -> proxy traffic
|
||||
pub decrypt_iv: u128,
|
||||
/// Key for encrypting proxy -> client traffic
|
||||
pub encrypt_key: [u8; 32],
|
||||
/// IV for encrypting proxy -> client traffic
|
||||
pub encrypt_iv: u128,
|
||||
/// Protocol tag (abridged/intermediate/secure)
|
||||
pub proto_tag: ProtoTag,
|
||||
/// Datacenter index
|
||||
pub dc_idx: i16,
|
||||
}
|
||||
|
||||
impl ObfuscationParams {
|
||||
/// Parse obfuscation parameters from handshake bytes
|
||||
/// Returns None if handshake doesn't match any user secret
|
||||
pub fn from_handshake(
|
||||
handshake: &[u8; HANDSHAKE_LEN],
|
||||
secrets: &[(String, Vec<u8>)], // (username, secret_bytes)
|
||||
) -> Option<(Self, String)> {
|
||||
// Extract prekey and IV for decryption
|
||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
// Reversed for encryption direction
|
||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
for (username, secret) in secrets {
|
||||
// Derive decryption key
|
||||
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
dec_key_input.extend_from_slice(dec_prekey);
|
||||
dec_key_input.extend_from_slice(secret);
|
||||
let decrypt_key = sha256(&dec_key_input);
|
||||
|
||||
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||
|
||||
// Create decryptor and decrypt handshake
|
||||
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
|
||||
let decrypted = decryptor.decrypt(handshake);
|
||||
|
||||
// Check protocol tag
|
||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||
Some(tag) => tag,
|
||||
None => continue, // Try next secret
|
||||
};
|
||||
|
||||
// Extract DC index
|
||||
let dc_idx = i16::from_le_bytes(
|
||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||
);
|
||||
|
||||
// Derive encryption key
|
||||
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
enc_key_input.extend_from_slice(enc_prekey);
|
||||
enc_key_input.extend_from_slice(secret);
|
||||
let encrypt_key = sha256(&enc_key_input);
|
||||
let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
||||
|
||||
return Some((
|
||||
ObfuscationParams {
|
||||
decrypt_key,
|
||||
decrypt_iv,
|
||||
encrypt_key,
|
||||
encrypt_iv,
|
||||
proto_tag,
|
||||
dc_idx,
|
||||
},
|
||||
username.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Create AES-CTR decryptor for client -> proxy direction
|
||||
pub fn create_decryptor(&self) -> AesCtr {
|
||||
AesCtr::new(&self.decrypt_key, self.decrypt_iv)
|
||||
}
|
||||
|
||||
/// Create AES-CTR encryptor for proxy -> client direction
|
||||
pub fn create_encryptor(&self) -> AesCtr {
|
||||
AesCtr::new(&self.encrypt_key, self.encrypt_iv)
|
||||
}
|
||||
|
||||
/// Get the combined encrypt key and IV for fast mode
|
||||
pub fn enc_key_iv(&self) -> Vec<u8> {
|
||||
let mut result = Vec::with_capacity(KEY_LEN + IV_LEN);
|
||||
result.extend_from_slice(&self.encrypt_key);
|
||||
result.extend_from_slice(&self.encrypt_iv.to_be_bytes());
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a valid random nonce for Telegram handshake
|
||||
pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; HANDSHAKE_LEN] {
|
||||
loop {
|
||||
let nonce_vec = random_bytes(HANDSHAKE_LEN);
|
||||
let mut nonce = [0u8; HANDSHAKE_LEN];
|
||||
nonce.copy_from_slice(&nonce_vec);
|
||||
|
||||
if is_valid_nonce(&nonce) {
|
||||
return nonce;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if nonce is valid (not matching reserved patterns)
|
||||
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
||||
// Check first byte
|
||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check first 4 bytes
|
||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check bytes 4-7
|
||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Prepare nonce for sending to Telegram
|
||||
pub fn prepare_tg_nonce(
|
||||
nonce: &mut [u8; HANDSHAKE_LEN],
|
||||
proto_tag: ProtoTag,
|
||||
enc_key_iv: Option<&[u8]>, // For fast mode
|
||||
) {
|
||||
// Set protocol tag
|
||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||
|
||||
// For fast mode, copy the reversed enc_key_iv
|
||||
if let Some(key_iv) = enc_key_iv {
|
||||
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
|
||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt the outgoing nonce for Telegram
|
||||
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||
// Derive encryption key from the nonce itself
|
||||
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||
let enc_key = sha256(key_iv);
|
||||
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
|
||||
|
||||
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||
|
||||
// Only encrypt from PROTO_TAG_POS onwards
|
||||
let mut result = nonce.to_vec();
|
||||
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
|
||||
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_nonce() {
|
||||
// Valid nonce
|
||||
let mut valid = [0x42u8; HANDSHAKE_LEN];
|
||||
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
|
||||
assert!(is_valid_nonce(&valid));
|
||||
|
||||
// Invalid: starts with 0xef
|
||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||
invalid[0] = 0xef;
|
||||
assert!(!is_valid_nonce(&invalid));
|
||||
|
||||
// Invalid: starts with HEAD
|
||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||
invalid[..4].copy_from_slice(b"HEAD");
|
||||
assert!(!is_valid_nonce(&invalid));
|
||||
|
||||
// Invalid: bytes 4-7 are zeros
|
||||
let mut invalid = [0x42u8; HANDSHAKE_LEN];
|
||||
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
|
||||
assert!(!is_valid_nonce(&invalid));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_nonce() {
|
||||
let mut counter = 0u8;
|
||||
let nonce = generate_nonce(|n| {
|
||||
counter = counter.wrapping_add(1);
|
||||
vec![counter; n]
|
||||
});
|
||||
|
||||
assert!(is_valid_nonce(&nonce));
|
||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
||||
}
|
||||
}
|
||||
244
src/protocol/tls.rs
Normal file
244
src/protocol/tls.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Fake TLS 1.3 Handshake
|
||||
|
||||
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
|
||||
use crate::error::{ProxyError, Result};
|
||||
use super::constants::*;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// TLS handshake digest length
|
||||
pub const TLS_DIGEST_LEN: usize = 32;
|
||||
/// Position of digest in TLS ClientHello
|
||||
pub const TLS_DIGEST_POS: usize = 11;
|
||||
/// Length to store for replay protection (first 16 bytes of digest)
|
||||
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||
|
||||
/// Time skew limits for anti-replay (in seconds)
|
||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||
|
||||
/// Result of validating TLS handshake
|
||||
#[derive(Debug)]
|
||||
pub struct TlsValidation {
|
||||
/// Username that validated
|
||||
pub user: String,
|
||||
/// Session ID from ClientHello
|
||||
pub session_id: Vec<u8>,
|
||||
/// Client digest for response generation
|
||||
pub digest: [u8; TLS_DIGEST_LEN],
|
||||
/// Timestamp extracted from digest
|
||||
pub timestamp: u32,
|
||||
}
|
||||
|
||||
/// Validate TLS ClientHello against user secrets
|
||||
pub fn validate_tls_handshake(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
ignore_time_skew: bool,
|
||||
) -> Option<TlsValidation> {
|
||||
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract digest
|
||||
let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||
.try_into()
|
||||
.ok()?;
|
||||
|
||||
// Extract session ID
|
||||
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
|
||||
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
return None;
|
||||
}
|
||||
|
||||
let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec();
|
||||
|
||||
// Build message for HMAC (with zeroed digest)
|
||||
let mut msg = handshake.to_vec();
|
||||
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
||||
|
||||
// Get current time
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
for (user, secret) in secrets {
|
||||
let computed = sha256_hmac(secret, &msg);
|
||||
|
||||
// XOR digests
|
||||
let xored: Vec<u8> = digest.iter()
|
||||
.zip(computed.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect();
|
||||
|
||||
// Check that first 28 bytes are zeros (timestamp in last 4)
|
||||
if !xored[..28].iter().all(|&b| b == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract timestamp
|
||||
let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap());
|
||||
let time_diff = now - timestamp as i64;
|
||||
|
||||
// Check time skew
|
||||
if !ignore_time_skew {
|
||||
// Allow very small timestamps (boot time instead of unix time)
|
||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000;
|
||||
|
||||
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return Some(TlsValidation {
|
||||
user: user.clone(),
|
||||
session_id,
|
||||
digest,
|
||||
timestamp,
|
||||
});
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Generate a fake X25519 public key for TLS
|
||||
/// This generates a value that looks like a valid X25519 key
|
||||
pub fn gen_fake_x25519_key() -> [u8; 32] {
|
||||
// For simplicity, just generate random 32 bytes
|
||||
// In real X25519, this would be a point on the curve
|
||||
let bytes = SECURE_RANDOM.bytes(32);
|
||||
bytes.try_into().unwrap()
|
||||
}
|
||||
|
||||
/// Build TLS ServerHello response
|
||||
pub fn build_server_hello(
|
||||
secret: &[u8],
|
||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||
session_id: &[u8],
|
||||
fake_cert_len: usize,
|
||||
) -> Vec<u8> {
|
||||
let x25519_key = gen_fake_x25519_key();
|
||||
|
||||
// TLS extensions
|
||||
let mut extensions = Vec::new();
|
||||
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder
|
||||
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension
|
||||
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve
|
||||
extensions.extend_from_slice(&x25519_key);
|
||||
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
|
||||
|
||||
// ServerHello body
|
||||
let mut srv_hello = Vec::new();
|
||||
srv_hello.extend_from_slice(&TLS_VERSION);
|
||||
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest
|
||||
srv_hello.push(session_id.len() as u8);
|
||||
srv_hello.extend_from_slice(session_id);
|
||||
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
|
||||
srv_hello.push(0x00); // No compression
|
||||
srv_hello.extend_from_slice(&extensions);
|
||||
|
||||
// Build complete packet
|
||||
let mut hello_pkt = Vec::new();
|
||||
|
||||
// ServerHello record
|
||||
hello_pkt.push(TLS_RECORD_HANDSHAKE);
|
||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
||||
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes());
|
||||
hello_pkt.push(0x02); // ServerHello message type
|
||||
let len_bytes = (srv_hello.len() as u32).to_be_bytes();
|
||||
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length
|
||||
hello_pkt.extend_from_slice(&srv_hello);
|
||||
|
||||
// Change Cipher Spec record
|
||||
hello_pkt.extend_from_slice(&[
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_VERSION[0], TLS_VERSION[1],
|
||||
0x00, 0x01, 0x01
|
||||
]);
|
||||
|
||||
// Application Data record (fake certificate)
|
||||
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
|
||||
hello_pkt.push(TLS_RECORD_APPLICATION);
|
||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
||||
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
|
||||
hello_pkt.extend_from_slice(&fake_cert);
|
||||
|
||||
// Compute HMAC for the response
|
||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len());
|
||||
hmac_input.extend_from_slice(client_digest);
|
||||
hmac_input.extend_from_slice(&hello_pkt);
|
||||
let response_digest = sha256_hmac(secret, &hmac_input);
|
||||
|
||||
// Insert computed digest
|
||||
// Position: after record header (5) + message type/length (4) + version (2) = 11
|
||||
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&response_digest);
|
||||
|
||||
hello_pkt
|
||||
}
|
||||
|
||||
/// Check if bytes look like a TLS ClientHello
|
||||
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
||||
if first_bytes.len() < 3 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TLS record header: 0x16 0x03 0x01
|
||||
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
||||
&& first_bytes[1] == 0x03
|
||||
&& first_bytes[2] == 0x01
|
||||
}
|
||||
|
||||
/// Parse TLS record header, returns (record_type, length)
|
||||
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
|
||||
// We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3
|
||||
if version != [0x03, 0x01] && version != TLS_VERSION {
|
||||
return None;
|
||||
}
|
||||
|
||||
let length = u16::from_be_bytes([header[3], header[4]]);
|
||||
Some((record_type, length))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_tls_handshake() {
|
||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
|
||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
|
||||
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data
|
||||
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version
|
||||
assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_tls_record_header() {
|
||||
let header = [0x16, 0x03, 0x01, 0x02, 0x00];
|
||||
let result = parse_tls_record_header(&header).unwrap();
|
||||
assert_eq!(result.0, TLS_RECORD_HANDSHAKE);
|
||||
assert_eq!(result.1, 512);
|
||||
|
||||
let header = [0x17, 0x03, 0x03, 0x40, 0x00];
|
||||
let result = parse_tls_record_header(&header).unwrap();
|
||||
assert_eq!(result.0, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(result.1, 16384);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gen_fake_x25519_key() {
|
||||
let key1 = gen_fake_x25519_key();
|
||||
let key2 = gen_fake_x25519_key();
|
||||
|
||||
assert_eq!(key1.len(), 32);
|
||||
assert_eq!(key2.len(), 32);
|
||||
assert_ne!(key1, key2); // Should be random
|
||||
}
|
||||
}
|
||||
378
src/proxy/client.rs
Normal file
378
src/proxy/client.rs
Normal file
@@ -0,0 +1,378 @@
|
||||
//! Client Handler
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, warn, error, trace};
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::{ProxyError, Result, HandshakeResult};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::tls;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{ConnectionPool, configure_client_socket};
|
||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
||||
use crate::crypto::AesCtr;
|
||||
|
||||
use super::handshake::{
|
||||
handle_tls_handshake, handle_mtproto_handshake,
|
||||
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
|
||||
};
|
||||
use super::relay::relay_bidirectional;
|
||||
use super::masking::handle_bad_client;
|
||||
|
||||
/// Client connection handler
|
||||
pub struct ClientHandler {
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create new client handler
|
||||
pub fn new(
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
stats,
|
||||
replay_checker,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a client connection
|
||||
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
|
||||
self.stats.increment_connects_all();
|
||||
|
||||
debug!(peer = %peer, "New connection");
|
||||
|
||||
// Configure socket
|
||||
if let Err(e) = configure_client_socket(
|
||||
&stream,
|
||||
self.config.client_keepalive,
|
||||
self.config.client_ack_timeout,
|
||||
) {
|
||||
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
||||
}
|
||||
|
||||
// Perform handshake with timeout
|
||||
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
|
||||
|
||||
let result = timeout(
|
||||
handshake_timeout,
|
||||
self.do_handshake(stream, peer)
|
||||
).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
debug!(peer = %peer, "Connection handled successfully");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||
}
|
||||
Err(_) => {
|
||||
self.stats.increment_handshake_timeouts();
|
||||
debug!(peer = %peer, "Handshake timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake and relay
|
||||
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
|
||||
// Read first bytes to determine handshake type
|
||||
let mut first_bytes = [0u8; 5];
|
||||
stream.read_exact(&mut first_bytes).await?;
|
||||
|
||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||
|
||||
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
|
||||
|
||||
if is_tls {
|
||||
self.handle_tls_client(stream, peer, first_bytes).await
|
||||
} else {
|
||||
self.handle_direct_client(stream, peer, first_bytes).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle TLS-wrapped client
|
||||
async fn handle_tls_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
// Read TLS handshake length
|
||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||
|
||||
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
||||
|
||||
if tls_len < 512 {
|
||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read full TLS handshake
|
||||
let mut handshake = vec![0u8; 5 + tls_len];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Split stream for reading/writing
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
|
||||
// Handle TLS handshake
|
||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||
&handshake,
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
// Read MTProto handshake through TLS
|
||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||
|
||||
// Handle MTProto handshake
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
&mtproto_handshake,
|
||||
tls_reader,
|
||||
tls_writer,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
true,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
// Handle authenticated client
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
}
|
||||
|
||||
/// Handle direct (non-TLS) client
|
||||
async fn handle_direct_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
// Check if non-TLS modes are enabled
|
||||
if !self.config.modes.classic && !self.config.modes.secure {
|
||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read rest of handshake
|
||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Split stream
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
|
||||
// Handle MTProto handshake
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
&handshake,
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
false,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
}
|
||||
|
||||
/// Handle authenticated client - connect to Telegram and relay
|
||||
async fn handle_authenticated_inner<R, W>(
|
||||
&self,
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let user = &success.user;
|
||||
|
||||
// Check user limits
|
||||
if let Err(e) = self.check_user_limits(user) {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Get datacenter address
|
||||
let dc_addr = self.get_dc_addr(success.dc_idx)?;
|
||||
|
||||
info!(
|
||||
user = %user,
|
||||
peer = %success.peer,
|
||||
dc = success.dc_idx,
|
||||
dc_addr = %dc_addr,
|
||||
proto = ?success.proto_tag,
|
||||
fast_mode = self.config.fast_mode,
|
||||
"Connecting to Telegram"
|
||||
);
|
||||
|
||||
// Connect to Telegram
|
||||
let tg_stream = self.pool.get(dc_addr).await?;
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
|
||||
|
||||
// Perform Telegram handshake and get crypto streams
|
||||
let (tg_reader, tg_writer) = self.do_tg_handshake(
|
||||
tg_stream,
|
||||
&success,
|
||||
).await?;
|
||||
|
||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
||||
|
||||
// Update stats
|
||||
self.stats.increment_user_connects(user);
|
||||
self.stats.increment_user_curr_connects(user);
|
||||
|
||||
// Relay traffic - передаём Arc::clone(&self.stats)
|
||||
let relay_result = relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
tg_writer,
|
||||
user,
|
||||
Arc::clone(&self.stats),
|
||||
).await;
|
||||
|
||||
// Update stats
|
||||
self.stats.decrement_user_curr_connects(user);
|
||||
|
||||
match &relay_result {
|
||||
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
|
||||
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"),
|
||||
}
|
||||
|
||||
relay_result
|
||||
}
|
||||
|
||||
/// Check user limits (expiration, connection count, data quota)
|
||||
fn check_user_limits(&self, user: &str) -> Result<()> {
|
||||
// Check expiration
|
||||
if let Some(expiration) = self.config.user_expirations.get(user) {
|
||||
if chrono::Utc::now() > *expiration {
|
||||
return Err(ProxyError::UserExpired { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection limit
|
||||
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
|
||||
let current = self.stats.get_user_curr_connects(user);
|
||||
if current >= *limit as u64 {
|
||||
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check data quota
|
||||
if let Some(quota) = self.config.user_data_quota.get(user) {
|
||||
let used = self.stats.get_user_total_octets(user);
|
||||
if used >= *quota {
|
||||
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get datacenter address by index
|
||||
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
|
||||
let idx = (dc_idx.abs() - 1) as usize;
|
||||
|
||||
let datacenters = if self.config.prefer_ipv6 {
|
||||
&*TG_DATACENTERS_V6
|
||||
} else {
|
||||
&*TG_DATACENTERS_V4
|
||||
};
|
||||
|
||||
datacenters.get(idx)
|
||||
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
|
||||
.ok_or_else(|| ProxyError::InvalidHandshake(
|
||||
format!("Invalid DC index: {}", dc_idx)
|
||||
))
|
||||
}
|
||||
|
||||
/// Perform handshake with Telegram server
|
||||
/// Returns crypto reader and writer for TG connection
|
||||
async fn do_tg_handshake(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
success: &HandshakeSuccess,
|
||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
||||
// Generate nonce with keys for TG
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||
success.proto_tag,
|
||||
&success.dec_key, // Client's dec key
|
||||
success.dec_iv,
|
||||
self.config.fast_mode,
|
||||
);
|
||||
|
||||
// Encrypt nonce
|
||||
let encrypted_nonce = encrypt_tg_nonce(&nonce);
|
||||
|
||||
debug!(
|
||||
peer = %success.peer,
|
||||
nonce_head = %hex::encode(&nonce[..16]),
|
||||
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
|
||||
"Sending nonce to Telegram"
|
||||
);
|
||||
|
||||
// Send to Telegram
|
||||
stream.write_all(&encrypted_nonce).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
debug!(peer = %success.peer, "Nonce sent to Telegram");
|
||||
|
||||
// Split stream and wrap with crypto
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
|
||||
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
|
||||
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
|
||||
|
||||
let tg_reader = CryptoReader::new(read_half, decryptor);
|
||||
let tg_writer = CryptoWriter::new(write_half, encryptor);
|
||||
|
||||
Ok((tg_reader, tg_writer))
|
||||
}
|
||||
}
|
||||
411
src/proxy/handshake.rs
Normal file
411
src/proxy/handshake.rs
Normal file
@@ -0,0 +1,411 @@
|
||||
//! MTProto Handshake Magics
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace, info};
|
||||
|
||||
use crate::crypto::{sha256, AesCtr};
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::tls;
|
||||
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
||||
use crate::error::{ProxyError, HandshakeResult};
|
||||
use crate::stats::ReplayChecker;
|
||||
use crate::config::ProxyConfig;
|
||||
|
||||
/// Result of successful handshake
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HandshakeSuccess {
|
||||
/// Authenticated user name
|
||||
pub user: String,
|
||||
/// Target datacenter index
|
||||
pub dc_idx: i16,
|
||||
/// Protocol variant (abridged/intermediate/secure)
|
||||
pub proto_tag: ProtoTag,
|
||||
/// Decryption key and IV (for reading from client)
|
||||
pub dec_key: [u8; 32],
|
||||
pub dec_iv: u128,
|
||||
/// Encryption key and IV (for writing to client)
|
||||
pub enc_key: [u8; 32],
|
||||
pub enc_iv: u128,
|
||||
/// Client address
|
||||
pub peer: SocketAddr,
|
||||
/// Whether TLS was used
|
||||
pub is_tls: bool,
|
||||
}
|
||||
|
||||
/// Handle fake TLS handshake
|
||||
pub async fn handle_tls_handshake<R, W>(
|
||||
handshake: &[u8],
|
||||
reader: R,
|
||||
mut writer: W,
|
||||
peer: SocketAddr,
|
||||
config: &ProxyConfig,
|
||||
replay_checker: &ReplayChecker,
|
||||
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||
|
||||
// Check minimum length
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
debug!(peer = %peer, "TLS handshake too short");
|
||||
return HandshakeResult::BadClient;
|
||||
}
|
||||
|
||||
// Extract digest for replay check
|
||||
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
||||
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
|
||||
// Check for replay
|
||||
if replay_checker.check_tls_digest(digest_half) {
|
||||
warn!(peer = %peer, "TLS replay attack detected");
|
||||
return HandshakeResult::BadClient;
|
||||
}
|
||||
|
||||
// Build secrets list
|
||||
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
|
||||
.filter_map(|(name, hex)| {
|
||||
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
|
||||
|
||||
// Validate handshake
|
||||
let validation = match tls::validate_tls_handshake(
|
||||
handshake,
|
||||
&secrets,
|
||||
config.ignore_time_skew,
|
||||
) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
debug!(peer = %peer, "TLS handshake validation failed - no matching user");
|
||||
return HandshakeResult::BadClient;
|
||||
}
|
||||
};
|
||||
|
||||
// Get secret for response
|
||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||
Some((_, s)) => s,
|
||||
None => return HandshakeResult::BadClient,
|
||||
};
|
||||
|
||||
// Build and send response
|
||||
let response = tls::build_server_hello(
|
||||
secret,
|
||||
&validation.digest,
|
||||
&validation.session_id,
|
||||
config.fake_cert_len,
|
||||
);
|
||||
|
||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||
|
||||
if let Err(e) = writer.write_all(&response).await {
|
||||
return HandshakeResult::Error(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
if let Err(e) = writer.flush().await {
|
||||
return HandshakeResult::Error(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
// Record for replay protection
|
||||
replay_checker.add_tls_digest(digest_half);
|
||||
|
||||
info!(
|
||||
peer = %peer,
|
||||
user = %validation.user,
|
||||
"TLS handshake successful"
|
||||
);
|
||||
|
||||
HandshakeResult::Success((
|
||||
FakeTlsReader::new(reader),
|
||||
FakeTlsWriter::new(writer),
|
||||
validation.user,
|
||||
))
|
||||
}
|
||||
|
||||
/// Handle MTProto obfuscation handshake
|
||||
pub async fn handle_mtproto_handshake<R, W>(
|
||||
handshake: &[u8; HANDSHAKE_LEN],
|
||||
reader: R,
|
||||
writer: W,
|
||||
peer: SocketAddr,
|
||||
config: &ProxyConfig,
|
||||
replay_checker: &ReplayChecker,
|
||||
is_tls: bool,
|
||||
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send,
|
||||
W: AsyncWrite + Unpin + Send,
|
||||
{
|
||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||
|
||||
// Extract prekey and IV
|
||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||
|
||||
debug!(
|
||||
peer = %peer,
|
||||
dec_prekey_iv = %hex::encode(dec_prekey_iv),
|
||||
"Extracted prekey+IV from handshake"
|
||||
);
|
||||
|
||||
// Check for replay
|
||||
if replay_checker.check_handshake(dec_prekey_iv) {
|
||||
warn!(peer = %peer, "MTProto replay attack detected");
|
||||
return HandshakeResult::BadClient;
|
||||
}
|
||||
|
||||
// Reversed for encryption direction
|
||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||
|
||||
// Try each user's secret
|
||||
for (user, secret_hex) in &config.users {
|
||||
let secret = match hex::decode(secret_hex) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// Derive decryption key
|
||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
dec_key_input.extend_from_slice(dec_prekey);
|
||||
dec_key_input.extend_from_slice(&secret);
|
||||
let dec_key = sha256(&dec_key_input);
|
||||
|
||||
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||
|
||||
// Decrypt handshake to check protocol tag
|
||||
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||
let decrypted = decryptor.decrypt(handshake);
|
||||
|
||||
trace!(
|
||||
peer = %peer,
|
||||
user = %user,
|
||||
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
|
||||
"Decrypted handshake tail"
|
||||
);
|
||||
|
||||
// Check protocol tag
|
||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||
Some(tag) => tag,
|
||||
None => {
|
||||
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
|
||||
|
||||
// Check if mode is enabled
|
||||
let mode_ok = match proto_tag {
|
||||
ProtoTag::Secure => {
|
||||
if is_tls { config.modes.tls } else { config.modes.secure }
|
||||
}
|
||||
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic,
|
||||
};
|
||||
|
||||
if !mode_ok {
|
||||
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract DC index
|
||||
let dc_idx = i16::from_le_bytes(
|
||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||
);
|
||||
|
||||
// Derive encryption key
|
||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
enc_key_input.extend_from_slice(enc_prekey);
|
||||
enc_key_input.extend_from_slice(&secret);
|
||||
let enc_key = sha256(&enc_key_input);
|
||||
|
||||
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
||||
|
||||
// Record for replay protection
|
||||
replay_checker.add_handshake(dec_prekey_iv);
|
||||
|
||||
// Create new cipher instances
|
||||
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||
|
||||
let success = HandshakeSuccess {
|
||||
user: user.clone(),
|
||||
dc_idx,
|
||||
proto_tag,
|
||||
dec_key,
|
||||
dec_iv,
|
||||
enc_key,
|
||||
enc_iv,
|
||||
peer,
|
||||
is_tls,
|
||||
};
|
||||
|
||||
info!(
|
||||
peer = %peer,
|
||||
user = %user,
|
||||
dc = dc_idx,
|
||||
proto = ?proto_tag,
|
||||
tls = is_tls,
|
||||
"MTProto handshake successful"
|
||||
);
|
||||
|
||||
return HandshakeResult::Success((
|
||||
CryptoReader::new(reader, decryptor),
|
||||
CryptoWriter::new(writer, encryptor),
|
||||
success,
|
||||
));
|
||||
}
|
||||
|
||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||
HandshakeResult::BadClient
|
||||
}
|
||||
|
||||
/// Generate nonce for Telegram connection
|
||||
///
|
||||
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
|
||||
/// This means: client's enc_key becomes TG's dec_key and vice versa.
|
||||
pub fn generate_tg_nonce(
|
||||
proto_tag: ProtoTag,
|
||||
client_dec_key: &[u8; 32],
|
||||
client_dec_iv: u128,
|
||||
fast_mode: bool,
|
||||
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
||||
loop {
|
||||
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
|
||||
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
||||
|
||||
// Check reserved patterns
|
||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Set protocol tag
|
||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||
|
||||
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
|
||||
// In fast mode, we make TG use the same keys as client but swapped:
|
||||
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
|
||||
// - What we encrypt TO TG = what we decrypt FROM client
|
||||
if fast_mode {
|
||||
// Put client's dec_key + dec_iv into nonce[8:56]
|
||||
// This will be used by TG for encryption TO us
|
||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
|
||||
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
|
||||
.copy_from_slice(&client_dec_iv.to_be_bytes());
|
||||
}
|
||||
|
||||
// Now compute what keys WE will use for TG connection
|
||||
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
|
||||
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
|
||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||
|
||||
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
||||
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
||||
|
||||
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
||||
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
||||
|
||||
debug!(
|
||||
fast_mode = fast_mode,
|
||||
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
|
||||
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
|
||||
"Generated TG nonce"
|
||||
);
|
||||
|
||||
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt nonce for sending to Telegram
|
||||
///
|
||||
/// Only the part from PROTO_TAG_POS onwards is encrypted.
|
||||
/// The encryption key is derived from enc_key_iv in the nonce itself.
|
||||
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||
// enc_key_iv is at nonce[8:56]
|
||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||
|
||||
// Key for encrypting is just the first 32 bytes of enc_key_iv
|
||||
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
||||
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
||||
|
||||
let mut encryptor = AesCtr::new(&key, iv);
|
||||
|
||||
// Encrypt the entire nonce first, then take only the encrypted tail
|
||||
let encrypted_full = encryptor.encrypt(nonce);
|
||||
|
||||
// Result: unencrypted head + encrypted tail
|
||||
let mut result = nonce[..PROTO_TAG_POS].to_vec();
|
||||
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
||||
|
||||
trace!(
|
||||
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
|
||||
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
|
||||
"Encrypted nonce tail"
|
||||
);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_tg_nonce() {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
||||
|
||||
// Check length
|
||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
||||
|
||||
// Check proto tag is set
|
||||
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
||||
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_tg_nonce() {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
|
||||
let (nonce, _, _, _, _) =
|
||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
||||
|
||||
let encrypted = encrypt_tg_nonce(&nonce);
|
||||
|
||||
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
||||
|
||||
// First PROTO_TAG_POS bytes should be unchanged
|
||||
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
||||
|
||||
// Rest should be different (encrypted)
|
||||
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
||||
}
|
||||
}
|
||||
115
src/proxy/masking.rs
Normal file
115
src/proxy/masking.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
//! Masking - forward unrecognized traffic to mask host
|
||||
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::transport::set_linger_zero;
|
||||
|
||||
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const MASK_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
/// Handle a bad client by forwarding to mask host
|
||||
pub async fn handle_bad_client(
|
||||
mut client: TcpStream,
|
||||
initial_data: &[u8],
|
||||
config: &ProxyConfig,
|
||||
) {
|
||||
if !config.mask {
|
||||
// Masking disabled, just consume data
|
||||
consume_client_data(client).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let mask_host = config.mask_host.as_deref()
|
||||
.unwrap_or(&config.tls_domain);
|
||||
let mask_port = config.mask_port;
|
||||
|
||||
debug!(
|
||||
host = %mask_host,
|
||||
port = mask_port,
|
||||
"Forwarding bad client to mask host"
|
||||
);
|
||||
|
||||
// Connect to mask host
|
||||
let mask_addr = format!("{}:{}", mask_host, mask_port);
|
||||
let connect_result = timeout(
|
||||
MASK_TIMEOUT,
|
||||
TcpStream::connect(&mask_addr)
|
||||
).await;
|
||||
|
||||
let mut mask_stream = match connect_result {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data(client).await;
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data(client).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Send initial data to mask host
|
||||
if mask_stream.write_all(initial_data).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Relay traffic
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut mask_read, mut mask_write) = mask_stream.into_split();
|
||||
|
||||
let c2m = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
loop {
|
||||
match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => {
|
||||
let _ = mask_write.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
if mask_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let m2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
loop {
|
||||
match mask_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => {
|
||||
let _ = client_write.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for either to complete
|
||||
tokio::select! {
|
||||
_ = c2m => {}
|
||||
_ = m2c => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Just consume all data from client without responding
|
||||
async fn consume_client_data(mut client: TcpStream) {
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
while let Ok(n) = client.read(&mut buf).await {
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
11
src/proxy/mod.rs
Normal file
11
src/proxy/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Proxy Defs
|
||||
|
||||
pub mod handshake;
|
||||
pub mod client;
|
||||
pub mod relay;
|
||||
pub mod masking;
|
||||
|
||||
pub use handshake::*;
|
||||
pub use client::ClientHandler;
|
||||
pub use relay::*;
|
||||
pub use masking::*;
|
||||
162
src/proxy/relay.rs
Normal file
162
src/proxy/relay.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Bidirectional Relay
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{debug, trace, warn};
|
||||
use crate::error::Result;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
const BUFFER_SIZE: usize = 65536;
|
||||
|
||||
/// Relay data bidirectionally between client and server
|
||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
mut client_reader: CR,
|
||||
mut client_writer: CW,
|
||||
mut server_reader: SR,
|
||||
mut server_writer: SW,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
) -> Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let user_c2s = user.to_string();
|
||||
let user_s2c = user.to_string();
|
||||
|
||||
// Используем Arc::clone вместо stats.clone()
|
||||
let stats_c2s = Arc::clone(&stats);
|
||||
let stats_s2c = Arc::clone(&stats);
|
||||
|
||||
let c2s_bytes = Arc::new(AtomicU64::new(0));
|
||||
let s2c_bytes = Arc::new(AtomicU64::new(0));
|
||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
||||
|
||||
// Client -> Server task
|
||||
let c2s = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
|
||||
loop {
|
||||
match client_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
debug!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
"Client closed connection (C->S)"
|
||||
);
|
||||
let _ = server_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
||||
stats_c2s.increment_user_msgs_from(&user_c2s);
|
||||
|
||||
trace!(
|
||||
user = %user_c2s,
|
||||
bytes = n,
|
||||
total = total_bytes,
|
||||
data_preview = %hex::encode(&buf[..n.min(32)]),
|
||||
"C->S data"
|
||||
);
|
||||
|
||||
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = server_writer.flush().await {
|
||||
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Server -> Client task
|
||||
let s2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
|
||||
loop {
|
||||
match server_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
debug!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
"Server closed connection (S->C)"
|
||||
);
|
||||
let _ = client_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
||||
stats_s2c.increment_user_msgs_to(&user_s2c);
|
||||
|
||||
trace!(
|
||||
user = %user_s2c,
|
||||
bytes = n,
|
||||
total = total_bytes,
|
||||
data_preview = %hex::encode(&buf[..n.min(32)]),
|
||||
"S->C data"
|
||||
);
|
||||
|
||||
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = client_writer.flush().await {
|
||||
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for either direction to complete
|
||||
tokio::select! {
|
||||
result = c2s => {
|
||||
if let Err(e) = result {
|
||||
warn!(error = %e, "C->S task panicked");
|
||||
}
|
||||
}
|
||||
result = s2c => {
|
||||
if let Err(e) = result {
|
||||
warn!(error = %e, "S->C task panicked");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
|
||||
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
|
||||
"Relay finished"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
223
src/stats/mod.rs
Normal file
223
src/stats/mod.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Statistics
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use lru::LruCache;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
/// Thread-safe statistics
|
||||
#[derive(Default)]
|
||||
pub struct Stats {
|
||||
// Global counters
|
||||
connects_all: AtomicU64,
|
||||
connects_bad: AtomicU64,
|
||||
handshake_timeouts: AtomicU64,
|
||||
|
||||
// Per-user stats
|
||||
user_stats: DashMap<String, UserStats>,
|
||||
|
||||
// Start time
|
||||
start_time: RwLock<Option<Instant>>,
|
||||
}
|
||||
|
||||
/// Per-user statistics
|
||||
#[derive(Default)]
|
||||
pub struct UserStats {
|
||||
pub connects: AtomicU64,
|
||||
pub curr_connects: AtomicU64,
|
||||
pub octets_from_client: AtomicU64,
|
||||
pub octets_to_client: AtomicU64,
|
||||
pub msgs_from_client: AtomicU64,
|
||||
pub msgs_to_client: AtomicU64,
|
||||
}
|
||||
|
||||
impl Stats {
|
||||
pub fn new() -> Self {
|
||||
let stats = Self::default();
|
||||
*stats.start_time.write() = Some(Instant::now());
|
||||
stats
|
||||
}
|
||||
|
||||
// Global stats
|
||||
pub fn increment_connects_all(&self) {
|
||||
self.connects_all.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_connects_bad(&self) {
|
||||
self.connects_bad.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_handshake_timeouts(&self) {
|
||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_connects_all(&self) -> u64 {
|
||||
self.connects_all.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_connects_bad(&self) -> u64 {
|
||||
self.connects_bad.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
// User stats
|
||||
pub fn increment_user_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_curr_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.curr_connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
stats.curr_connects.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_from_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_to_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_from_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_to_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
.map(|s| {
|
||||
s.octets_from_client.load(Ordering::Relaxed) +
|
||||
s.octets_to_client.load(Ordering::Relaxed)
|
||||
})
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn uptime_secs(&self) -> f64 {
|
||||
self.start_time.read()
|
||||
.map(|t| t.elapsed().as_secs_f64())
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
// Arc<Stats> Hightech Stats :D
|
||||
|
||||
/// Replay attack checker using LRU cache
|
||||
pub struct ReplayChecker {
|
||||
handshakes: RwLock<LruCache<Vec<u8>, ()>>,
|
||||
tls_digests: RwLock<LruCache<Vec<u8>, ()>>,
|
||||
}
|
||||
|
||||
impl ReplayChecker {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
let cap = NonZeroUsize::new(capacity.max(1)).unwrap();
|
||||
Self {
|
||||
handshakes: RwLock::new(LruCache::new(cap)),
|
||||
tls_digests: RwLock::new(LruCache::new(cap)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
||||
self.handshakes.read().contains(&data.to_vec())
|
||||
}
|
||||
|
||||
pub fn add_handshake(&self, data: &[u8]) {
|
||||
self.handshakes.write().put(data.to_vec(), ());
|
||||
}
|
||||
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
self.tls_digests.read().contains(&data.to_vec())
|
||||
}
|
||||
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
self.tls_digests.write().put(data.to_vec(), ());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stats_shared_counters() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
// Симулируем использование из разных "задач"
|
||||
let stats1 = Arc::clone(&stats);
|
||||
let stats2 = Arc::clone(&stats);
|
||||
|
||||
stats1.increment_connects_all();
|
||||
stats2.increment_connects_all();
|
||||
stats1.increment_connects_all();
|
||||
|
||||
// Все инкременты должны быть видны
|
||||
assert_eq!(stats.get_connects_all(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_stats_shared() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let stats1 = Arc::clone(&stats);
|
||||
let stats2 = Arc::clone(&stats);
|
||||
|
||||
stats1.add_user_octets_from("user1", 100);
|
||||
stats2.add_user_octets_from("user1", 200);
|
||||
stats1.add_user_octets_to("user1", 50);
|
||||
|
||||
assert_eq!(stats.get_user_total_octets("user1"), 350);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_user_connects() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
stats.increment_user_curr_connects("user1");
|
||||
stats.increment_user_curr_connects("user1");
|
||||
assert_eq!(stats.get_user_curr_connects("user1"), 2);
|
||||
|
||||
stats.decrement_user_curr_connects("user1");
|
||||
assert_eq!(stats.get_user_curr_connects("user1"), 1);
|
||||
}
|
||||
}
|
||||
474
src/stream/crypto_stream.rs
Normal file
474
src/stream/crypto_stream.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
//! Encrypted stream wrappers using AES-CTR
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use crate::crypto::AesCtr;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// Reader that decrypts data using AES-CTR
|
||||
pub struct CryptoReader<R> {
|
||||
upstream: R,
|
||||
decryptor: AesCtr,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
|
||||
impl<R> CryptoReader<R> {
|
||||
/// Create new crypto reader
|
||||
pub fn new(upstream: R, decryptor: AesCtr) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
decryptor,
|
||||
buffer: BytesMut::with_capacity(8192),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if !this.buffer.is_empty() {
|
||||
let to_copy = this.buffer.len().min(buf.remaining());
|
||||
buf.put_slice(&this.buffer.split_to(to_copy));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Zero-copy Reader
|
||||
let before = buf.filled().len();
|
||||
|
||||
match Pin::new(&mut this.upstream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let after = buf.filled().len();
|
||||
let bytes_read = after - before;
|
||||
|
||||
if bytes_read > 0 {
|
||||
// Decrypt in-place
|
||||
let filled = buf.filled_mut();
|
||||
this.decryptor.apply(&mut filled[before..after]);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> CryptoReader<R> {
|
||||
/// Read and decrypt exactly n bytes with Async
|
||||
pub async fn read_exact_decrypt(&mut self, n: usize) -> Result<Bytes> {
|
||||
let mut result = BytesMut::with_capacity(n);
|
||||
|
||||
if !self.buffer.is_empty() {
|
||||
let to_take = self.buffer.len().min(n);
|
||||
result.extend_from_slice(&self.buffer.split_to(to_take));
|
||||
}
|
||||
|
||||
// Reread
|
||||
while result.len() < n {
|
||||
let mut temp = vec![0u8; n - result.len()];
|
||||
let read = self.upstream.read(&mut temp).await?;
|
||||
|
||||
if read == 0 {
|
||||
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
self.decryptor.apply(&mut temp[..read]);
|
||||
result.extend_from_slice(&temp[..read]);
|
||||
}
|
||||
|
||||
Ok(result.freeze())
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer that encrypts data using AES-CTR
|
||||
pub struct CryptoWriter<W> {
|
||||
upstream: W,
|
||||
encryptor: AesCtr,
|
||||
pending: BytesMut,
|
||||
}
|
||||
|
||||
impl<W> CryptoWriter<W> {
|
||||
pub fn new(upstream: W, encryptor: AesCtr) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
encryptor,
|
||||
pending: BytesMut::with_capacity(8192),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
|
||||
if !this.pending.is_empty() {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
// Pending Null
|
||||
if buf.is_empty() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
let mut encrypted = buf.to_vec();
|
||||
this.encryptor.apply(&mut encrypted);
|
||||
|
||||
// Write Try
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
if written < encrypted.len() {
|
||||
// Partial write — сохраняем остаток в pending
|
||||
this.pending.extend_from_slice(&encrypted[written..]);
|
||||
}
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => {
|
||||
this.pending.extend_from_slice(&encrypted);
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
while !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"Failed to write pending data during flush",
|
||||
)));
|
||||
}
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
Pin::new(&mut this.upstream).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
while !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
break;
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Passthrough stream for fast mode - no encryption/decryption
|
||||
pub struct PassthroughStream<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> PassthroughStream<S> {
|
||||
pub fn new(inner: S) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> S {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for PassthroughStream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable};
|
||||
use tokio::io::duplex;
|
||||
|
||||
/// Mock writer
|
||||
struct PartialWriter {
|
||||
chunk_size: usize,
|
||||
data: Vec<u8>,
|
||||
write_count: usize,
|
||||
}
|
||||
|
||||
impl PartialWriter {
|
||||
fn new(chunk_size: usize) -> Self {
|
||||
Self {
|
||||
chunk_size,
|
||||
data: Vec::new(),
|
||||
write_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for PartialWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.write_count += 1;
|
||||
let to_write = buf.len().min(self.chunk_size);
|
||||
self.data.extend_from_slice(&buf[..to_write]);
|
||||
Poll::Ready(Ok(to_write))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn noop_waker() -> Waker {
|
||||
const VTABLE: RawWakerVTable = RawWakerVTable::new(
|
||||
|_| RawWaker::new(std::ptr::null(), &VTABLE),
|
||||
|_| {},
|
||||
|_| {},
|
||||
|_| {},
|
||||
);
|
||||
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crypto_writer_partial_write_correctness() {
|
||||
let key = [0x42u8; 32];
|
||||
let iv = 12345u128;
|
||||
|
||||
// 10-byte Writer
|
||||
let mock_writer = PartialWriter::new(10);
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
|
||||
|
||||
let waker = noop_waker();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
// 25 byte
|
||||
let original = b"Hello, this is test data!";
|
||||
|
||||
// First Write
|
||||
let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original);
|
||||
assert!(matches!(result, Poll::Ready(Ok(25))));
|
||||
|
||||
// Flush before continue Pending
|
||||
loop {
|
||||
match Pin::new(&mut crypto_writer).poll_flush(&mut cx) {
|
||||
Poll::Ready(Ok(())) => break,
|
||||
Poll::Ready(Err(e)) => panic!("Flush error: {}", e),
|
||||
Poll::Pending => continue,
|
||||
}
|
||||
}
|
||||
|
||||
// Write Check
|
||||
let encrypted = &crypto_writer.upstream.data;
|
||||
assert_eq!(encrypted.len(), 25);
|
||||
|
||||
// Decrypt + Verify
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let mut decrypted = encrypted.clone();
|
||||
decryptor.apply(&mut decrypted);
|
||||
|
||||
assert_eq!(&decrypted, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crypto_writer_multiple_partial_writes() {
|
||||
let key = [0xAB; 32];
|
||||
let iv = 9999u128;
|
||||
|
||||
let mock_writer = PartialWriter::new(3);
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
|
||||
|
||||
let waker = noop_waker();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let data1 = b"First";
|
||||
let data2 = b"Second";
|
||||
let data3 = b"Third";
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap();
|
||||
// Flush
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap();
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap();
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
// Assemble
|
||||
let mut expected = Vec::new();
|
||||
expected.extend_from_slice(data1);
|
||||
expected.extend_from_slice(data2);
|
||||
expected.extend_from_slice(data3);
|
||||
|
||||
// Decrypt
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let mut decrypted = crypto_writer.upstream.data.clone();
|
||||
decryptor.apply(&mut decrypted);
|
||||
|
||||
assert_eq!(decrypted, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crypto_stream_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
let iv = 12345u128;
|
||||
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
let mut writer = CryptoWriter::new(client, encryptor);
|
||||
let mut reader = CryptoReader::new(server, decryptor);
|
||||
|
||||
// Write
|
||||
let original = b"Hello, encrypted world!";
|
||||
writer.write_all(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
// Read
|
||||
let mut buf = vec![0u8; original.len()];
|
||||
reader.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(&buf, original);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crypto_stream_large_data() {
|
||||
let key = [0x55u8; 32];
|
||||
let iv = 777u128;
|
||||
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
let mut writer = CryptoWriter::new(client, encryptor);
|
||||
let mut reader = CryptoReader::new(server, decryptor);
|
||||
|
||||
// Hugeload
|
||||
let original: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
|
||||
|
||||
// Write
|
||||
let write_data = original.clone();
|
||||
let write_handle = tokio::spawn(async move {
|
||||
writer.write_all(&write_data).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
// Read
|
||||
let mut received = Vec::new();
|
||||
let mut buf = vec![0u8; 1024];
|
||||
loop {
|
||||
match reader.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => received.extend_from_slice(&buf[..n]),
|
||||
Err(e) => panic!("Read error: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
write_handle.await.unwrap();
|
||||
|
||||
assert_eq!(received, original);
|
||||
}
|
||||
}
|
||||
585
src/stream/frame_stream.rs
Normal file
585
src/stream/frame_stream.rs
Normal file
@@ -0,0 +1,585 @@
|
||||
//! MTProto frame stream wrappers
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::crypto::crc32;
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
|
||||
/// Reader for abridged MTProto framing
|
||||
pub struct AbridgedFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> AbridgedFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
||||
/// Read a frame and return (data, metadata)
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read length byte
|
||||
let mut len_byte = [0u8];
|
||||
self.upstream.read_exact(&mut len_byte).await?;
|
||||
|
||||
let mut len = len_byte[0] as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len >= 0x80 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80;
|
||||
}
|
||||
|
||||
// Extended length (3 bytes)
|
||||
if len == 0x7f {
|
||||
let mut len_bytes = [0u8; 3];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||
}
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len * 4;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; byte_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for abridged MTProto framing
|
||||
pub struct AbridgedFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> AbridgedFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
/// Write a frame
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if data.len() % 4 != 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
// Simple ACK: send reversed data
|
||||
if meta.simple_ack {
|
||||
let reversed: Vec<u8> = data.iter().rev().copied().collect();
|
||||
self.upstream.write_all(&reversed).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let len_div_4 = data.len() / 4;
|
||||
|
||||
if len_div_4 < 0x7f {
|
||||
// Short length (1 byte)
|
||||
self.upstream.write_all(&[len_div_4 as u8]).await?;
|
||||
} else if len_div_4 < (1 << 24) {
|
||||
// Long length (4 bytes: 0x7f + 3 bytes)
|
||||
let mut header = [0x7f, 0, 0, 0];
|
||||
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
|
||||
self.upstream.write_all(&header).await?;
|
||||
} else {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {} bytes", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
self.upstream.write_all(data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Intermediate Frame =============
|
||||
|
||||
/// Reader for intermediate MTProto framing
|
||||
pub struct IntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> IntermediateFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for intermediate MTProto framing
|
||||
pub struct IntermediateFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> IntermediateFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
} else {
|
||||
let len_bytes = (data.len() as u32).to_le_bytes();
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Secure Intermediate Frame =============
|
||||
|
||||
/// Reader for secure intermediate MTProto framing (with padding)
|
||||
pub struct SecureIntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> SecureIntermediateFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
// Strip padding (not aligned to 4)
|
||||
if len % 4 != 0 {
|
||||
let actual_len = len - (len % 4);
|
||||
data.truncate(actual_len);
|
||||
}
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for secure intermediate MTProto framing
|
||||
pub struct SecureIntermediateFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> SecureIntermediateFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Add random padding (0-3 bytes)
|
||||
let padding_len = SECURE_RANDOM.range(4);
|
||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
self.upstream.write_all(&padding).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Full MTProto Frame (with CRC) =============
|
||||
|
||||
/// Reader for full MTProto framing with sequence numbers and CRC32
|
||||
pub struct MtprotoFrameReader<R> {
|
||||
upstream: R,
|
||||
seq_no: i32,
|
||||
}
|
||||
|
||||
impl<R> MtprotoFrameReader<R> {
|
||||
pub fn new(upstream: R, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<Bytes> {
|
||||
loop {
|
||||
// Read length (4 bytes)
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
let len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Skip padding-only messages
|
||||
if len == 4 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate length
|
||||
if len < MIN_MSG_LEN || len > MAX_MSG_LEN || len % PADDING_FILLER.len() != 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid message length: {}", len),
|
||||
));
|
||||
}
|
||||
|
||||
// Read sequence number
|
||||
let mut seq_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut seq_bytes).await?;
|
||||
let msg_seq = i32::from_le_bytes(seq_bytes);
|
||||
|
||||
if msg_seq != self.seq_no {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
|
||||
));
|
||||
}
|
||||
self.seq_no += 1;
|
||||
|
||||
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
|
||||
let data_len = len - 12;
|
||||
let mut data = vec![0u8; data_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
// Read and verify CRC32
|
||||
let mut crc_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut crc_bytes).await?;
|
||||
let expected_crc = u32::from_le_bytes(crc_bytes);
|
||||
|
||||
// Compute CRC over len + seq + data
|
||||
let mut crc_input = Vec::with_capacity(8 + data_len);
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
crc_input.extend_from_slice(&seq_bytes);
|
||||
crc_input.extend_from_slice(&data);
|
||||
let computed_crc = crc32(&crc_input);
|
||||
|
||||
if computed_crc != expected_crc {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
|
||||
));
|
||||
}
|
||||
|
||||
return Ok(Bytes::from(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer for full MTProto framing
|
||||
pub struct MtprotoFrameWriter<W> {
|
||||
upstream: W,
|
||||
seq_no: i32,
|
||||
}
|
||||
|
||||
impl<W> MtprotoFrameWriter<W> {
|
||||
pub fn new(upstream: W, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
|
||||
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
|
||||
let len = msg.len() + 12;
|
||||
|
||||
let len_bytes = (len as u32).to_le_bytes();
|
||||
let seq_bytes = self.seq_no.to_le_bytes();
|
||||
self.seq_no += 1;
|
||||
|
||||
// Compute CRC
|
||||
let mut crc_input = Vec::with_capacity(8 + msg.len());
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
crc_input.extend_from_slice(&seq_bytes);
|
||||
crc_input.extend_from_slice(msg);
|
||||
let checksum = crc32(&crc_input);
|
||||
let crc_bytes = checksum.to_le_bytes();
|
||||
|
||||
// Calculate padding for CBC alignment
|
||||
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
|
||||
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
|
||||
let padding_count = padding_needed / PADDING_FILLER.len();
|
||||
|
||||
// Write everything
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(&seq_bytes).await?;
|
||||
self.upstream.write_all(msg).await?;
|
||||
self.upstream.write_all(&crc_bytes).await?;
|
||||
|
||||
for _ in 0..padding_count {
|
||||
self.upstream.write_all(&PADDING_FILLER).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Frame Type Enum =============
|
||||
|
||||
/// Enum for different frame stream types
|
||||
pub enum FrameReaderKind<R> {
|
||||
Abridged(AbridgedFrameReader<R>),
|
||||
Intermediate(IntermediateFrameReader<R>),
|
||||
SecureIntermediate(SecureIntermediateFrameReader<R>),
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
|
||||
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
match self {
|
||||
FrameReaderKind::Abridged(r) => r.read_frame().await,
|
||||
FrameReaderKind::Intermediate(r) => r.read_frame().await,
|
||||
FrameReaderKind::SecureIntermediate(r) => r.read_frame().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum FrameWriterKind<W> {
|
||||
Abridged(AbridgedFrameWriter<W>),
|
||||
Intermediate(IntermediateFrameWriter<W>),
|
||||
SecureIntermediate(SecureIntermediateFrameWriter<W>),
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
|
||||
FrameWriterKind::Intermediate(w) => w.write_frame(data, meta).await,
|
||||
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.flush().await,
|
||||
FrameWriterKind::Intermediate(w) => w.flush().await,
|
||||
FrameWriterKind::SecureIntermediate(w) => w.flush().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
// Short frame
|
||||
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_long_frame() {
|
||||
let (client, server) = duplex(65536);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
// Long frame (> 0x7f words = 508 bytes)
|
||||
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
|
||||
let padded_len = (data.len() + 3) / 4 * 4;
|
||||
let mut padded = data.clone();
|
||||
padded.resize(padded_len, 0);
|
||||
|
||||
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &padded[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_intermediate_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = IntermediateFrameWriter::new(client);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = SecureIntermediateFrameWriter::new(client);
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
// Received should have padding stripped to align to 4
|
||||
let expected_len = (data.len() / 4) * 4;
|
||||
assert_eq!(received.len(), expected_len);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mtproto_frame_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = MtprotoFrameWriter::new(client, 0);
|
||||
let mut reader = MtprotoFrameReader::new(server, 0);
|
||||
|
||||
// Message must be padded properly
|
||||
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
|
||||
writer.write_frame(&data).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let received = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_frame_reader_kind() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
|
||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
}
|
||||
10
src/stream/mod.rs
Normal file
10
src/stream/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! Stream wrappers for MTProto protocol layers
|
||||
|
||||
pub mod traits;
|
||||
pub mod crypto_stream;
|
||||
pub mod tls_stream;
|
||||
pub mod frame_stream;
|
||||
|
||||
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
||||
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||
pub use frame_stream::*;
|
||||
277
src/stream/tls_stream.rs
Normal file
277
src/stream/tls_stream.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! Fake TLS 1.3 stream wrappers
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use crate::protocol::constants::{
|
||||
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
MAX_TLS_CHUNK_SIZE,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// Reader that unwraps TLS 1.3 records
|
||||
pub struct FakeTlsReader<R> {
|
||||
upstream: R,
|
||||
buffer: BytesMut,
|
||||
pending_read: Option<PendingTlsRead>,
|
||||
}
|
||||
|
||||
struct PendingTlsRead {
|
||||
record_type: u8,
|
||||
remaining: usize,
|
||||
}
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
/// Create new fake TLS reader
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
buffer: BytesMut::with_capacity(16384),
|
||||
pending_read: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||
/// Read exactly n bytes through TLS layer
|
||||
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
|
||||
while self.buffer.len() < n {
|
||||
let data = self.read_tls_record().await?;
|
||||
if data.is_empty() {
|
||||
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
|
||||
}
|
||||
self.buffer.extend_from_slice(&data);
|
||||
}
|
||||
|
||||
Ok(self.buffer.split_to(n).freeze())
|
||||
}
|
||||
|
||||
/// Read a single TLS record
|
||||
async fn read_tls_record(&mut self) -> Result<Vec<u8>> {
|
||||
loop {
|
||||
// Read TLS record header (5 bytes)
|
||||
let mut header = [0u8; 5];
|
||||
self.upstream.read_exact(&mut header).await?;
|
||||
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
|
||||
// Validate version
|
||||
if version != TLS_VERSION {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid TLS version: {:02x?}", version),
|
||||
));
|
||||
}
|
||||
|
||||
// Read record body
|
||||
let mut data = vec![0u8; length];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
match record_type {
|
||||
TLS_RECORD_CHANGE_CIPHER => continue, // Skip
|
||||
TLS_RECORD_APPLICATION => return Ok(data),
|
||||
_ => {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Unexpected TLS record type: 0x{:02x}", record_type),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
// Drain buffer first
|
||||
if !self.buffer.is_empty() {
|
||||
let to_copy = self.buffer.len().min(buf.remaining());
|
||||
buf.put_slice(&self.buffer.split_to(to_copy));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// We need to read a TLS record, but poll_read doesn't support async/await
|
||||
// So we'll do a simplified version that reads header synchronously
|
||||
|
||||
// Read header
|
||||
let mut header = [0u8; 5];
|
||||
let mut header_buf = ReadBuf::new(&mut header);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if header_buf.filled().len() < 5 {
|
||||
// Need more data - store what we have and return pending
|
||||
// For simplicity, we'll just return empty
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
|
||||
let record_type = header[0];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
|
||||
if record_type == TLS_RECORD_CHANGE_CIPHER {
|
||||
// Skip this record, try again
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
if record_type != TLS_RECORD_APPLICATION {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"Invalid TLS record type",
|
||||
)));
|
||||
}
|
||||
|
||||
// Read body
|
||||
let mut body = vec![0u8; length];
|
||||
let mut body_buf = ReadBuf::new(&mut body);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let filled = body_buf.filled();
|
||||
let to_copy = filled.len().min(buf.remaining());
|
||||
buf.put_slice(&filled[..to_copy]);
|
||||
|
||||
if filled.len() > to_copy {
|
||||
self.buffer.extend_from_slice(&filled[to_copy..]);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer that wraps data in TLS 1.3 records
|
||||
pub struct FakeTlsWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> FakeTlsWriter<W> {
|
||||
/// Create new fake TLS writer
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
// Build TLS record
|
||||
let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE);
|
||||
let chunk = &buf[..chunk_size];
|
||||
|
||||
let mut record = Vec::with_capacity(5 + chunk_size);
|
||||
record.push(TLS_RECORD_APPLICATION);
|
||||
record.extend_from_slice(&TLS_VERSION);
|
||||
record.push((chunk_size >> 8) as u8);
|
||||
record.push(chunk_size as u8);
|
||||
record.extend_from_slice(chunk);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_write(cx, &record) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
if written >= 5 {
|
||||
Poll::Ready(Ok(written - 5))
|
||||
} else {
|
||||
Poll::Ready(Ok(0))
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.upstream).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.upstream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
/// Write all data wrapped in TLS records (async method)
|
||||
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
|
||||
for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) {
|
||||
let header = [
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
(chunk.len() >> 8) as u8,
|
||||
chunk.len() as u8,
|
||||
];
|
||||
|
||||
self.upstream.write_all(&header).await?;
|
||||
self.upstream.write_all(chunk).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
let original = b"Hello, fake TLS!";
|
||||
writer.write_all_tls(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let received = reader.read_exact(original.len()).await.unwrap();
|
||||
assert_eq!(&received[..], original);
|
||||
}
|
||||
}
|
||||
113
src/stream/traits.rs
Normal file
113
src/stream/traits.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! Stream traits and common types
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::io::Result;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// Extra metadata for frames
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FrameMeta {
|
||||
/// Quick ACK requested
|
||||
pub quickack: bool,
|
||||
/// This is a simple ACK message
|
||||
pub simple_ack: bool,
|
||||
/// Skip sending this frame
|
||||
pub skip_send: bool,
|
||||
}
|
||||
|
||||
impl FrameMeta {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_quickack(mut self) -> Self {
|
||||
self.quickack = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_simple_ack(mut self) -> Self {
|
||||
self.simple_ack = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of reading a frame
|
||||
#[derive(Debug)]
|
||||
pub enum ReadFrameResult {
|
||||
/// Frame data with metadata
|
||||
Frame(Bytes, FrameMeta),
|
||||
/// Connection closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
/// Trait for streams that wrap another stream
|
||||
pub trait LayeredStream<U> {
|
||||
/// Get reference to upstream
|
||||
fn upstream(&self) -> &U;
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
fn upstream_mut(&mut self) -> &mut U;
|
||||
|
||||
/// Consume self and return upstream
|
||||
fn into_upstream(self) -> U;
|
||||
}
|
||||
|
||||
/// A split read half of a stream
|
||||
pub struct ReadHalf<R> {
|
||||
inner: R,
|
||||
}
|
||||
|
||||
impl<R> ReadHalf<R> {
|
||||
pub fn new(inner: R) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for ReadHalf<R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
/// A split write half of a stream
|
||||
pub struct WriteHalf<W> {
|
||||
inner: W,
|
||||
}
|
||||
|
||||
impl<W> WriteHalf<W> {
|
||||
pub fn new(inner: W) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
9
src/transport/mod.rs
Normal file
9
src/transport/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Transport layer: connection pooling, socket utilities, proxy protocol
|
||||
|
||||
pub mod pool;
|
||||
pub mod proxy_protocol;
|
||||
pub mod socket;
|
||||
|
||||
pub use pool::ConnectionPool;
|
||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||
pub use socket::*;
|
||||
338
src/transport/pool.rs
Normal file
338
src/transport/pool.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
//! Connection Pool
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::timeout;
|
||||
use parking_lot::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
use crate::error::{ProxyError, Result};
|
||||
use super::socket::configure_tcp_socket;
|
||||
|
||||
/// A pooled connection with metadata
|
||||
struct PooledConnection {
|
||||
stream: TcpStream,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
/// Internal pool state for a single endpoint
|
||||
struct PoolInner {
|
||||
/// Available connections
|
||||
connections: Vec<PooledConnection>,
|
||||
/// Number of connections being established
|
||||
pending: usize,
|
||||
}
|
||||
|
||||
impl PoolInner {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
connections: Vec::new(),
|
||||
pending: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection pool configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoolConfig {
|
||||
/// Maximum connections per endpoint
|
||||
pub max_connections: usize,
|
||||
/// Connection timeout
|
||||
pub connect_timeout: Duration,
|
||||
/// Maximum idle time before connection is dropped
|
||||
pub max_idle_time: Duration,
|
||||
/// Enable TCP keepalive
|
||||
pub keepalive: bool,
|
||||
/// Keepalive interval
|
||||
pub keepalive_interval: Duration,
|
||||
}
|
||||
|
||||
impl Default for PoolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections: 64,
|
||||
connect_timeout: Duration::from_secs(10),
|
||||
max_idle_time: Duration::from_secs(60),
|
||||
keepalive: true,
|
||||
keepalive_interval: Duration::from_secs(40),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe connection pool
|
||||
pub struct ConnectionPool {
|
||||
/// Per-endpoint pools
|
||||
pools: RwLock<HashMap<SocketAddr, Arc<Mutex<PoolInner>>>>,
|
||||
/// Configuration
|
||||
config: PoolConfig,
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
/// Create new connection pool with default config
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(PoolConfig::default())
|
||||
}
|
||||
|
||||
/// Create connection pool with custom config
|
||||
pub fn with_config(config: PoolConfig) -> Self {
|
||||
Self {
|
||||
pools: RwLock::new(HashMap::new()),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create pool for an endpoint
|
||||
fn get_or_create_pool(&self, addr: SocketAddr) -> Arc<Mutex<PoolInner>> {
|
||||
// Fast path with read lock
|
||||
{
|
||||
let pools = self.pools.read();
|
||||
if let Some(pool) = pools.get(&addr) {
|
||||
return Arc::clone(pool);
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path with write lock
|
||||
let mut pools = self.pools.write();
|
||||
pools.entry(addr)
|
||||
.or_insert_with(|| Arc::new(Mutex::new(PoolInner::new())))
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Get a connection to the specified address
|
||||
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
|
||||
let pool = self.get_or_create_pool(addr);
|
||||
|
||||
// Try to get an existing connection
|
||||
{
|
||||
let mut inner = pool.lock().await;
|
||||
|
||||
// Remove stale connections
|
||||
let now = Instant::now();
|
||||
inner.connections.retain(|c| {
|
||||
now.duration_since(c.created_at) < self.config.max_idle_time
|
||||
});
|
||||
|
||||
// Try to find a usable connection
|
||||
while let Some(conn) = inner.connections.pop() {
|
||||
// Check if connection is still alive
|
||||
if is_connection_alive(&conn.stream) {
|
||||
debug!(addr = %addr, "Reusing pooled connection");
|
||||
return Ok(conn.stream);
|
||||
}
|
||||
debug!(addr = %addr, "Discarding dead pooled connection");
|
||||
}
|
||||
|
||||
// Check if we can create a new connection
|
||||
let total = inner.connections.len() + inner.pending;
|
||||
if total >= self.config.max_connections {
|
||||
return Err(ProxyError::ConnectionTimeout {
|
||||
addr: addr.to_string()
|
||||
});
|
||||
}
|
||||
|
||||
inner.pending += 1;
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
debug!(addr = %addr, "Creating new connection");
|
||||
let result = self.create_connection(addr).await;
|
||||
|
||||
// Decrement pending count
|
||||
{
|
||||
let mut inner = pool.lock().await;
|
||||
inner.pending = inner.pending.saturating_sub(1);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Create a new connection to the address
|
||||
async fn create_connection(&self, addr: SocketAddr) -> Result<TcpStream> {
|
||||
let connect_future = TcpStream::connect(addr);
|
||||
|
||||
let stream = timeout(self.config.connect_timeout, connect_future)
|
||||
.await
|
||||
.map_err(|_| ProxyError::ConnectionTimeout {
|
||||
addr: addr.to_string()
|
||||
})?
|
||||
.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::ConnectionRefused {
|
||||
ProxyError::ConnectionRefused { addr: addr.to_string() }
|
||||
} else {
|
||||
ProxyError::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Configure socket
|
||||
configure_tcp_socket(
|
||||
&stream,
|
||||
self.config.keepalive,
|
||||
self.config.keepalive_interval,
|
||||
)?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
/// Return a connection to the pool
|
||||
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
|
||||
let pool = self.get_or_create_pool(addr);
|
||||
let mut inner = pool.lock().await;
|
||||
|
||||
if inner.connections.len() < self.config.max_connections {
|
||||
inner.connections.push(PooledConnection {
|
||||
stream,
|
||||
created_at: Instant::now(),
|
||||
});
|
||||
debug!(addr = %addr, pool_size = inner.connections.len(), "Returned connection to pool");
|
||||
} else {
|
||||
debug!(addr = %addr, "Pool full, dropping connection");
|
||||
}
|
||||
}
|
||||
|
||||
/// Close all pooled connections
|
||||
pub async fn close_all(&self) {
|
||||
let pools = self.pools.read();
|
||||
for (addr, pool) in pools.iter() {
|
||||
let mut inner = pool.lock().await;
|
||||
let count = inner.connections.len();
|
||||
inner.connections.clear();
|
||||
debug!(addr = %addr, count = count, "Closed pooled connections");
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pool statistics
|
||||
pub async fn stats(&self) -> PoolStats {
|
||||
let pools = self.pools.read();
|
||||
let mut total_connections = 0;
|
||||
let mut total_pending = 0;
|
||||
let mut endpoints = 0;
|
||||
|
||||
for pool in pools.values() {
|
||||
let inner = pool.lock().await;
|
||||
total_connections += inner.connections.len();
|
||||
total_pending += inner.pending;
|
||||
endpoints += 1;
|
||||
}
|
||||
|
||||
PoolStats {
|
||||
endpoints,
|
||||
total_connections,
|
||||
total_pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnectionPool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Pool statistics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoolStats {
|
||||
pub endpoints: usize,
|
||||
pub total_connections: usize,
|
||||
pub total_pending: usize,
|
||||
}
|
||||
|
||||
/// Check if a TCP connection is still alive (non-blocking)
|
||||
fn is_connection_alive(stream: &TcpStream) -> bool {
|
||||
// Try a non-blocking read to check connection state
|
||||
let mut buf = [0u8; 1];
|
||||
match stream.try_read(&mut buf) {
|
||||
Ok(0) => false, // Connection closed
|
||||
Ok(_) => true, // Data available (shouldn't happen, but connection is alive)
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => true, // No data, but alive
|
||||
Err(_) => false, // Some error, assume dead
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection pool with custom initialization
|
||||
pub struct InitializingPool<F> {
|
||||
pool: ConnectionPool,
|
||||
init_fn: F,
|
||||
}
|
||||
|
||||
impl<F, Fut> InitializingPool<F>
|
||||
where
|
||||
F: Fn(TcpStream, SocketAddr) -> Fut + Send + Sync,
|
||||
Fut: std::future::Future<Output = Result<TcpStream>> + Send,
|
||||
{
|
||||
/// Create pool with initialization function
|
||||
pub fn new(config: PoolConfig, init_fn: F) -> Self {
|
||||
Self {
|
||||
pool: ConnectionPool::with_config(config),
|
||||
init_fn,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an initialized connection
|
||||
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
|
||||
let stream = self.pool.get(addr).await?;
|
||||
(self.init_fn)(stream, addr).await
|
||||
}
|
||||
|
||||
/// Return connection to pool
|
||||
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
|
||||
self.pool.put(addr, stream).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pool_basic() {
|
||||
// Start a test server
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
// Accept connections in background
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let _ = listener.accept().await;
|
||||
}
|
||||
});
|
||||
|
||||
let pool = ConnectionPool::new();
|
||||
|
||||
// Get a connection
|
||||
let conn1 = pool.get(addr).await.unwrap();
|
||||
|
||||
// Return it to pool
|
||||
pool.put(addr, conn1).await;
|
||||
|
||||
// Get again (should reuse)
|
||||
let _conn2 = pool.get(addr).await.unwrap();
|
||||
|
||||
let stats = pool.stats().await;
|
||||
assert_eq!(stats.endpoints, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pool_connection_refused() {
|
||||
let pool = ConnectionPool::with_config(PoolConfig {
|
||||
connect_timeout: Duration::from_millis(100),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Try to connect to a port that's not listening
|
||||
let result = pool.get("127.0.0.1:1".parse().unwrap()).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pool_stats() {
|
||||
let pool = ConnectionPool::new();
|
||||
|
||||
let stats = pool.stats().await;
|
||||
assert_eq!(stats.endpoints, 0);
|
||||
assert_eq!(stats.total_connections, 0);
|
||||
}
|
||||
}
|
||||
381
src/transport/proxy_protocol.rs
Normal file
381
src/transport/proxy_protocol.rs
Normal file
@@ -0,0 +1,381 @@
|
||||
//! HAProxy PROXY protocol V1/V2
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
/// PROXY protocol v1 signature
|
||||
const PROXY_V1_SIGNATURE: &[u8] = b"PROXY ";
|
||||
|
||||
/// PROXY protocol v2 signature
|
||||
const PROXY_V2_SIGNATURE: &[u8] = &[
|
||||
0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a,
|
||||
0x51, 0x55, 0x49, 0x54, 0x0a
|
||||
];
|
||||
|
||||
/// Minimum length for v1 detection
|
||||
const PROXY_V1_MIN_LEN: usize = 6;
|
||||
|
||||
/// Minimum length for v2 header
|
||||
const PROXY_V2_MIN_LEN: usize = 16;
|
||||
|
||||
/// Address families for v2
|
||||
mod address_family {
|
||||
pub const UNSPEC: u8 = 0x0;
|
||||
pub const INET: u8 = 0x1;
|
||||
pub const INET6: u8 = 0x2;
|
||||
}
|
||||
|
||||
/// Information extracted from PROXY protocol header
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyProtocolInfo {
|
||||
/// Source (client) address
|
||||
pub src_addr: SocketAddr,
|
||||
/// Destination address (optional)
|
||||
pub dst_addr: Option<SocketAddr>,
|
||||
/// Protocol version used (1 or 2)
|
||||
pub version: u8,
|
||||
}
|
||||
|
||||
impl ProxyProtocolInfo {
|
||||
/// Create info with just source address
|
||||
pub fn new(src_addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
src_addr,
|
||||
dst_addr: None,
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse PROXY protocol header from a stream
|
||||
///
|
||||
/// Returns the parsed info or an error if the header is invalid.
|
||||
/// The stream position is advanced past the header.
|
||||
pub async fn parse_proxy_protocol<R: AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
default_peer: SocketAddr,
|
||||
) -> Result<ProxyProtocolInfo> {
|
||||
// Read enough bytes to detect version
|
||||
let mut header = [0u8; PROXY_V2_MIN_LEN];
|
||||
reader.read_exact(&mut header[..PROXY_V1_MIN_LEN]).await
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
|
||||
// Check for v1
|
||||
if header[..PROXY_V1_MIN_LEN] == PROXY_V1_SIGNATURE[..] {
|
||||
return parse_v1(reader, default_peer).await;
|
||||
}
|
||||
|
||||
// Read rest for v2 detection
|
||||
reader.read_exact(&mut header[PROXY_V1_MIN_LEN..]).await
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
|
||||
// Check for v2
|
||||
if header[..12] == PROXY_V2_SIGNATURE[..] {
|
||||
return parse_v2(reader, &header, default_peer).await;
|
||||
}
|
||||
|
||||
Err(ProxyError::InvalidProxyProtocol)
|
||||
}
|
||||
|
||||
/// Parse PROXY protocol v1
|
||||
async fn parse_v1<R: AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
default_peer: SocketAddr,
|
||||
) -> Result<ProxyProtocolInfo> {
|
||||
// Read until CRLF (max 107 bytes total for v1)
|
||||
let mut line = Vec::with_capacity(128);
|
||||
line.extend_from_slice(PROXY_V1_SIGNATURE);
|
||||
|
||||
loop {
|
||||
let mut byte = [0u8];
|
||||
reader.read_exact(&mut byte).await
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
line.push(byte[0]);
|
||||
|
||||
if line.ends_with(b"\r\n") {
|
||||
break;
|
||||
}
|
||||
|
||||
if line.len() > 256 {
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the line: PROXY TCP4/TCP6/UNKNOWN src_ip dst_ip src_port dst_port
|
||||
let line_str = std::str::from_utf8(&line[PROXY_V1_MIN_LEN..line.len() - 2])
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
|
||||
let parts: Vec<&str> = line_str.split_whitespace().collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
|
||||
match parts[0] {
|
||||
"TCP4" | "TCP6" if parts.len() >= 5 => {
|
||||
let src_ip: IpAddr = parts[1].parse()
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
let dst_ip: IpAddr = parts[2].parse()
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
let src_port: u16 = parts[3].parse()
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
let dst_port: u16 = parts[4].parse()
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
|
||||
Ok(ProxyProtocolInfo {
|
||||
src_addr: SocketAddr::new(src_ip, src_port),
|
||||
dst_addr: Some(SocketAddr::new(dst_ip, dst_port)),
|
||||
version: 1,
|
||||
})
|
||||
}
|
||||
"UNKNOWN" => {
|
||||
// UNKNOWN means no address info, use default
|
||||
Ok(ProxyProtocolInfo {
|
||||
src_addr: default_peer,
|
||||
dst_addr: None,
|
||||
version: 1,
|
||||
})
|
||||
}
|
||||
_ => Err(ProxyError::InvalidProxyProtocol),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse PROXY protocol v2
|
||||
async fn parse_v2<R: AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
header: &[u8; PROXY_V2_MIN_LEN],
|
||||
default_peer: SocketAddr,
|
||||
) -> Result<ProxyProtocolInfo> {
|
||||
let version_command = header[12];
|
||||
let version = version_command >> 4;
|
||||
let command = version_command & 0x0f;
|
||||
|
||||
// Must be version 2
|
||||
if version != 2 {
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
|
||||
let family_protocol = header[13];
|
||||
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
|
||||
|
||||
// Read address data
|
||||
let mut addr_data = vec![0u8; addr_len];
|
||||
if addr_len > 0 {
|
||||
reader.read_exact(&mut addr_data).await
|
||||
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
|
||||
}
|
||||
|
||||
// LOCAL command (0x0) - use default peer
|
||||
if command == 0 {
|
||||
return Ok(ProxyProtocolInfo {
|
||||
src_addr: default_peer,
|
||||
dst_addr: None,
|
||||
version: 2,
|
||||
});
|
||||
}
|
||||
|
||||
// PROXY command (0x1) - parse addresses
|
||||
if command != 1 {
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
|
||||
let family = family_protocol >> 4;
|
||||
|
||||
match family {
|
||||
address_family::INET if addr_len >= 12 => {
|
||||
// IPv4: 4 + 4 + 2 + 2 = 12 bytes
|
||||
let src_ip = Ipv4Addr::new(
|
||||
addr_data[0], addr_data[1],
|
||||
addr_data[2], addr_data[3]
|
||||
);
|
||||
let dst_ip = Ipv4Addr::new(
|
||||
addr_data[4], addr_data[5],
|
||||
addr_data[6], addr_data[7]
|
||||
);
|
||||
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
|
||||
|
||||
Ok(ProxyProtocolInfo {
|
||||
src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
|
||||
dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)),
|
||||
version: 2,
|
||||
})
|
||||
}
|
||||
address_family::INET6 if addr_len >= 36 => {
|
||||
// IPv6: 16 + 16 + 2 + 2 = 36 bytes
|
||||
let src_ip = Ipv6Addr::from(
|
||||
<[u8; 16]>::try_from(&addr_data[0..16]).unwrap()
|
||||
);
|
||||
let dst_ip = Ipv6Addr::from(
|
||||
<[u8; 16]>::try_from(&addr_data[16..32]).unwrap()
|
||||
);
|
||||
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
|
||||
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
|
||||
|
||||
Ok(ProxyProtocolInfo {
|
||||
src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
|
||||
dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)),
|
||||
version: 2,
|
||||
})
|
||||
}
|
||||
address_family::UNSPEC => {
|
||||
Ok(ProxyProtocolInfo {
|
||||
src_addr: default_peer,
|
||||
dst_addr: None,
|
||||
version: 2,
|
||||
})
|
||||
}
|
||||
_ => Err(ProxyError::InvalidProxyProtocol),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for PROXY protocol v1 header
|
||||
pub struct ProxyProtocolV1Builder {
|
||||
family: &'static str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl ProxyProtocolV1Builder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
family: "UNKNOWN",
|
||||
src_addr: None,
|
||||
dst_addr: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tcp4(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
|
||||
self.family = "TCP4";
|
||||
self.src_addr = Some(src);
|
||||
self.dst_addr = Some(dst);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tcp6(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
|
||||
self.family = "TCP6";
|
||||
self.src_addr = Some(src);
|
||||
self.dst_addr = Some(dst);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(&self) -> Vec<u8> {
|
||||
match (self.src_addr, self.dst_addr) {
|
||||
(Some(src), Some(dst)) => {
|
||||
format!(
|
||||
"PROXY {} {} {} {} {}\r\n",
|
||||
self.family,
|
||||
src.ip(),
|
||||
dst.ip(),
|
||||
src.port(),
|
||||
dst.port()
|
||||
).into_bytes()
|
||||
}
|
||||
_ => b"PROXY UNKNOWN\r\n".to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProxyProtocolV1Builder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_v1_tcp4() {
|
||||
let header = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
|
||||
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
|
||||
let default = "0.0.0.0:0".parse().unwrap();
|
||||
|
||||
// Simulate that we've already read the signature
|
||||
let info = parse_v1(&mut cursor, default).await.unwrap();
|
||||
|
||||
assert_eq!(info.version, 1);
|
||||
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
|
||||
assert_eq!(info.src_addr.port(), 12345);
|
||||
assert!(info.dst_addr.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_v1_unknown() {
|
||||
let header = b"PROXY UNKNOWN\r\n";
|
||||
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
|
||||
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
|
||||
|
||||
let info = parse_v1(&mut cursor, default).await.unwrap();
|
||||
|
||||
assert_eq!(info.version, 1);
|
||||
assert_eq!(info.src_addr, default);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_v2_tcp4() {
|
||||
// v2 header for TCP4
|
||||
let mut header = [0u8; 16];
|
||||
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
|
||||
header[12] = 0x21; // v2, PROXY command
|
||||
header[13] = 0x11; // AF_INET, STREAM
|
||||
header[14] = 0x00;
|
||||
header[15] = 0x0c; // 12 bytes of address data
|
||||
|
||||
let addr_data = [
|
||||
192, 168, 1, 1, // src IP
|
||||
10, 0, 0, 1, // dst IP
|
||||
0x30, 0x39, // src port (12345)
|
||||
0x01, 0xbb, // dst port (443)
|
||||
];
|
||||
|
||||
let mut cursor = Cursor::new(addr_data.to_vec());
|
||||
let default = "0.0.0.0:0".parse().unwrap();
|
||||
|
||||
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
|
||||
|
||||
assert_eq!(info.version, 2);
|
||||
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
|
||||
assert_eq!(info.src_addr.port(), 12345);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_v2_local() {
|
||||
let mut header = [0u8; 16];
|
||||
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
|
||||
header[12] = 0x20; // v2, LOCAL command
|
||||
header[13] = 0x00;
|
||||
header[14] = 0x00;
|
||||
header[15] = 0x00; // 0 bytes of address data
|
||||
|
||||
let mut cursor = Cursor::new(Vec::new());
|
||||
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
|
||||
|
||||
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
|
||||
|
||||
assert_eq!(info.version, 2);
|
||||
assert_eq!(info.src_addr, default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_builder() {
|
||||
let src: SocketAddr = "192.168.1.1:12345".parse().unwrap();
|
||||
let dst: SocketAddr = "10.0.0.1:443".parse().unwrap();
|
||||
|
||||
let header = ProxyProtocolV1Builder::new()
|
||||
.tcp4(src, dst)
|
||||
.build();
|
||||
|
||||
let expected = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
|
||||
assert_eq!(header, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_builder_unknown() {
|
||||
let header = ProxyProtocolV1Builder::new().build();
|
||||
assert_eq!(header, b"PROXY UNKNOWN\r\n");
|
||||
}
|
||||
}
|
||||
230
src/transport/socket.rs
Normal file
230
src/transport/socket.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
//! TCP Socket Configuration
|
||||
|
||||
use std::io::Result;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
|
||||
use tracing::debug;
|
||||
|
||||
/// Configure TCP socket with recommended settings for proxy use
|
||||
pub fn configure_tcp_socket(
|
||||
stream: &TcpStream,
|
||||
keepalive: bool,
|
||||
keepalive_interval: Duration,
|
||||
) -> Result<()> {
|
||||
let socket = socket2::SockRef::from(stream);
|
||||
|
||||
// Disable Nagle's algorithm for lower latency
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
// Set keepalive if enabled
|
||||
if keepalive {
|
||||
let keepalive = TcpKeepalive::new()
|
||||
.with_time(keepalive_interval);
|
||||
|
||||
// Platform-specific keepalive settings
|
||||
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
|
||||
let keepalive = keepalive.with_interval(keepalive_interval);
|
||||
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
}
|
||||
|
||||
// Set buffer sizes
|
||||
set_buffer_sizes(&socket, 65536, 65536)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set socket buffer sizes
|
||||
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
|
||||
// These may fail on some systems, so we ignore errors
|
||||
let _ = socket.set_recv_buffer_size(recv);
|
||||
let _ = socket.set_send_buffer_size(send);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configure socket for accepting client connections
|
||||
pub fn configure_client_socket(
|
||||
stream: &TcpStream,
|
||||
keepalive_secs: u64,
|
||||
ack_timeout_secs: u64,
|
||||
) -> Result<()> {
|
||||
let socket = socket2::SockRef::from(stream);
|
||||
|
||||
// Disable Nagle's algorithm
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
// Set keepalive
|
||||
let keepalive = TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(keepalive_secs));
|
||||
|
||||
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
|
||||
let keepalive = keepalive.with_interval(Duration::from_secs(keepalive_secs));
|
||||
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
|
||||
// Set TCP user timeout (Linux only)
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let fd = stream.as_raw_fd();
|
||||
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
|
||||
unsafe {
|
||||
libc::setsockopt(
|
||||
fd,
|
||||
libc::IPPROTO_TCP,
|
||||
libc::TCP_USER_TIMEOUT,
|
||||
&timeout_ms as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set socket to send RST on close (for masking)
|
||||
pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
|
||||
let socket = socket2::SockRef::from(stream);
|
||||
socket.set_linger(Some(Duration::ZERO))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new TCP socket for outgoing connections
|
||||
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||
let domain = if addr.is_ipv4() {
|
||||
Domain::IPV4
|
||||
} else {
|
||||
Domain::IPV6
|
||||
};
|
||||
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
|
||||
// Set non-blocking
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
// Disable Nagle
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
/// Get local address of a socket
|
||||
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||
stream.local_addr().ok()
|
||||
}
|
||||
|
||||
/// Get peer address of a socket
|
||||
pub fn get_peer_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||
stream.peer_addr().ok()
|
||||
}
|
||||
|
||||
/// Check if address is IPv6
|
||||
pub fn is_ipv6(addr: &SocketAddr) -> bool {
|
||||
addr.is_ipv6()
|
||||
}
|
||||
|
||||
/// Parse IPv4-mapped IPv6 address to IPv4
|
||||
pub fn normalize_ip(addr: SocketAddr) -> SocketAddr {
|
||||
match addr {
|
||||
SocketAddr::V6(v6) => {
|
||||
if let Some(v4) = v6.ip().to_ipv4_mapped() {
|
||||
SocketAddr::new(std::net::IpAddr::V4(v4), v6.port())
|
||||
} else {
|
||||
addr
|
||||
}
|
||||
}
|
||||
_ => addr,
|
||||
}
|
||||
}
|
||||
|
||||
/// Socket options for server listening
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ListenOptions {
|
||||
/// Enable SO_REUSEADDR
|
||||
pub reuse_addr: bool,
|
||||
/// Enable SO_REUSEPORT (Linux/BSD)
|
||||
pub reuse_port: bool,
|
||||
/// Backlog size
|
||||
pub backlog: u32,
|
||||
/// IPv6 only (disable dual-stack)
|
||||
pub ipv6_only: bool,
|
||||
}
|
||||
|
||||
impl Default for ListenOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
reuse_addr: true,
|
||||
reuse_port: true,
|
||||
backlog: 1024,
|
||||
ipv6_only: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a listening socket with the specified options
|
||||
pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Socket> {
|
||||
let domain = if addr.is_ipv4() {
|
||||
Domain::IPV4
|
||||
} else {
|
||||
Domain::IPV6
|
||||
};
|
||||
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
|
||||
if options.reuse_addr {
|
||||
socket.set_reuse_address(true)?;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
if options.reuse_port {
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
|
||||
if addr.is_ipv6() && options.ipv6_only {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
socket.set_nonblocking(true)?;
|
||||
socket.bind(&addr.into())?;
|
||||
socket.listen(options.backlog as i32)?;
|
||||
|
||||
debug!(addr = %addr, "Created listening socket");
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_configure_socket() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let stream = TcpStream::connect(addr).await.unwrap();
|
||||
configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_ip() {
|
||||
// IPv4 stays IPv4
|
||||
let v4: SocketAddr = "192.168.1.1:8080".parse().unwrap();
|
||||
assert_eq!(normalize_ip(v4), v4);
|
||||
|
||||
// Pure IPv6 stays IPv6
|
||||
let v6: SocketAddr = "[::1]:8080".parse().unwrap();
|
||||
assert_eq!(normalize_ip(v6), v6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_listen_options_default() {
|
||||
let opts = ListenOptions::default();
|
||||
assert!(opts.reuse_addr);
|
||||
assert!(opts.reuse_port);
|
||||
assert_eq!(opts.backlog, 1024);
|
||||
}
|
||||
}
|
||||
118
src/util/ip.rs
Normal file
118
src/util/ip.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
//! IP Addr Detect
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Detected IP addresses
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct IpInfo {
|
||||
pub ipv4: Option<IpAddr>,
|
||||
pub ipv6: Option<IpAddr>,
|
||||
}
|
||||
|
||||
impl IpInfo {
|
||||
/// Check if any IP is detected
|
||||
pub fn has_any(&self) -> bool {
|
||||
self.ipv4.is_some() || self.ipv6.is_some()
|
||||
}
|
||||
|
||||
/// Get preferred IP (IPv6 if available and preferred)
|
||||
pub fn preferred(&self, prefer_ipv6: bool) -> Option<IpAddr> {
|
||||
if prefer_ipv6 {
|
||||
self.ipv6.or(self.ipv4)
|
||||
} else {
|
||||
self.ipv4.or(self.ipv6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// URLs for IP detection
|
||||
const IPV4_URLS: &[&str] = &[
|
||||
"http://v4.ident.me/",
|
||||
"http://ipv4.icanhazip.com/",
|
||||
"http://api.ipify.org/",
|
||||
];
|
||||
|
||||
const IPV6_URLS: &[&str] = &[
|
||||
"http://v6.ident.me/",
|
||||
"http://ipv6.icanhazip.com/",
|
||||
"http://api6.ipify.org/",
|
||||
];
|
||||
|
||||
/// Detect public IP addresses
|
||||
pub async fn detect_ip() -> IpInfo {
|
||||
let mut info = IpInfo::default();
|
||||
|
||||
// Detect IPv4
|
||||
for url in IPV4_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv4() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv4 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect IPv6
|
||||
for url in IPV6_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv6() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv6 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !info.has_any() {
|
||||
warn!("Failed to detect public IP address");
|
||||
}
|
||||
|
||||
info
|
||||
}
|
||||
|
||||
/// Fetch IP from URL
|
||||
async fn fetch_ip(url: &str) -> Option<IpAddr> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(5))
|
||||
.build()
|
||||
.ok()?;
|
||||
|
||||
let response = client.get(url).send().await.ok()?;
|
||||
let text = response.text().await.ok()?;
|
||||
|
||||
text.trim().parse().ok()
|
||||
}
|
||||
|
||||
/// Synchronous IP detection (for startup)
|
||||
pub fn detect_ip_sync() -> IpInfo {
|
||||
tokio::runtime::Handle::current().block_on(detect_ip())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ip_info() {
|
||||
let info = IpInfo::default();
|
||||
assert!(!info.has_any());
|
||||
|
||||
let info = IpInfo {
|
||||
ipv4: Some("1.2.3.4".parse().unwrap()),
|
||||
ipv6: None,
|
||||
};
|
||||
assert!(info.has_any());
|
||||
assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap()));
|
||||
assert_eq!(info.preferred(true), Some("1.2.3.4".parse().unwrap()));
|
||||
|
||||
let info = IpInfo {
|
||||
ipv4: Some("1.2.3.4".parse().unwrap()),
|
||||
ipv6: Some("::1".parse().unwrap()),
|
||||
};
|
||||
assert_eq!(info.preferred(false), Some("1.2.3.4".parse().unwrap()));
|
||||
assert_eq!(info.preferred(true), Some("::1".parse().unwrap()));
|
||||
}
|
||||
}
|
||||
7
src/util/mod.rs
Normal file
7
src/util/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Utils
|
||||
|
||||
pub mod ip;
|
||||
pub mod time;
|
||||
|
||||
pub use ip::*;
|
||||
pub use time::*;
|
||||
76
src/util/time.rs
Normal file
76
src/util/time.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
//! Time Sync
|
||||
|
||||
use std::time::Duration;
|
||||
use chrono::{DateTime, Utc};
|
||||
use tracing::{debug, warn, error};
|
||||
|
||||
const TIME_SYNC_URL: &str = "https://core.telegram.org/getProxySecret";
|
||||
const MAX_TIME_SKEW_SECS: i64 = 30;
|
||||
|
||||
/// Time sync result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TimeSyncResult {
|
||||
pub server_time: DateTime<Utc>,
|
||||
pub local_time: DateTime<Utc>,
|
||||
pub skew_secs: i64,
|
||||
pub is_skewed: bool,
|
||||
}
|
||||
|
||||
/// Check time synchronization with Telegram servers
|
||||
pub async fn check_time_sync() -> Option<TimeSyncResult> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.ok()?;
|
||||
|
||||
let response = client.get(TIME_SYNC_URL).send().await.ok()?;
|
||||
|
||||
// Get Date header
|
||||
let date_header = response.headers().get("date")?;
|
||||
let date_str = date_header.to_str().ok()?;
|
||||
|
||||
// Parse date
|
||||
let server_time = DateTime::parse_from_rfc2822(date_str)
|
||||
.ok()?
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let local_time = Utc::now();
|
||||
let skew_secs = (local_time - server_time).num_seconds();
|
||||
let is_skewed = skew_secs.abs() > MAX_TIME_SKEW_SECS;
|
||||
|
||||
let result = TimeSyncResult {
|
||||
server_time,
|
||||
local_time,
|
||||
skew_secs,
|
||||
is_skewed,
|
||||
};
|
||||
|
||||
if is_skewed {
|
||||
warn!(
|
||||
server = %server_time,
|
||||
local = %local_time,
|
||||
skew = skew_secs,
|
||||
"Time skew detected"
|
||||
);
|
||||
} else {
|
||||
debug!(skew = skew_secs, "Time sync OK");
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Background time sync task
|
||||
pub async fn time_sync_task(check_interval: Duration) -> ! {
|
||||
loop {
|
||||
if let Some(result) = check_time_sync().await {
|
||||
if result.is_skewed {
|
||||
error!(
|
||||
"System clock is off by {} seconds. Please sync your clock.",
|
||||
result.skew_secs
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(check_interval).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user