Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fa6867056 | ||
|
|
54ea6efdd0 | ||
|
|
27ac32a901 | ||
|
|
829f53c123 | ||
|
|
43eae6127d | ||
|
|
a03212c8cc | ||
|
|
2613969a7c | ||
|
|
be1b2db867 | ||
|
|
8fbee8701b | ||
|
|
952d160870 | ||
|
|
91ae6becde | ||
|
|
e1f576e4fe | ||
|
|
a7556cabdc | ||
|
|
b2e8d16bb1 | ||
|
|
d95e762812 | ||
|
|
384f927fc3 | ||
|
|
1b7c09ae18 | ||
|
|
85cb4092d5 | ||
|
|
5016160ac3 | ||
|
|
4f007f3128 | ||
|
|
7746a1177c | ||
|
|
2bb2a2983f | ||
|
|
5778be4f6e | ||
|
|
f443d3dfc7 | ||
|
|
450cf180ad | ||
|
|
84fa7face0 | ||
|
|
f8a2ea1972 | ||
|
|
96d0a6bdfa | ||
|
|
eeee55e8ea | ||
|
|
7be179b3c0 | ||
|
|
b2e034f8f1 | ||
|
|
ffe5a6cfb7 | ||
|
|
0e096ca8fb | ||
|
|
50658525cf | ||
|
|
4fd5ff4e83 | ||
|
|
df4f312fec | ||
|
|
7d9a8b99b4 | ||
|
|
06f34e55cd | ||
|
|
153cb7f3a3 | ||
|
|
7f8904a989 | ||
|
|
0ee71a59a0 | ||
|
|
45c7347e22 | ||
|
|
3805237d74 | ||
|
|
5b281bf7fd | ||
|
|
d64cccd52c | ||
|
|
016fdada68 | ||
|
|
2c2ceeaf54 | ||
|
|
dd6badd786 | ||
|
|
50e72368c8 |
41
.github/workflows/rust.yml
vendored
Normal file
41
.github/workflows/rust.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: Rust
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install latest stable Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Cache cargo registry & build artifacts
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Build Release
|
||||
run: cargo build --release --verbose
|
||||
|
||||
- name: Check for unused dependencies
|
||||
run: cargo udeps || true
|
||||
@@ -46,6 +46,7 @@ base64 = "0.21"
|
||||
url = "2.5"
|
||||
regex = "1.10"
|
||||
once_cell = "1.19"
|
||||
crossbeam-queue = "0.3"
|
||||
|
||||
# HTTP
|
||||
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
||||
|
||||
255
README.md
255
README.md
@@ -1,2 +1,253 @@
|
||||
# telemt
|
||||
MTProxy for Telegram on Rust + Tokio
|
||||
# Telemt - MTProxy on Rust + Tokio
|
||||
|
||||
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
|
||||
|
||||
# GOTO
|
||||
- [Features](#features)
|
||||
- [Quick Start Guide](#quick-start-guide)
|
||||
- [How to use?](#how-to-use)
|
||||
- [Systemd Method](#telemt-via-systemd)
|
||||
- [Configuration](#configuration)
|
||||
- [Minimal Configuration](#minimal-configuration-for-first-start)
|
||||
- [Advanced](#advanced)
|
||||
- [Adtag](#adtag)
|
||||
- [Listening and Announce IPs](#listening-and-announce-ips)
|
||||
- [Upstream Manager](#upstream-manager)
|
||||
- [IP](#bind-on-ip)
|
||||
- [SOCKS](#socks45-as-upstream)
|
||||
- [FAQ](#faq)
|
||||
- [Telegram Calls](#telegram-calls-via-mtproxy)
|
||||
- [DPI](#how-does-dpi-see-mtproxy-tls)
|
||||
- [Whitelist on Network Level](#whitelist-on-ip)
|
||||
- [Build](#build)
|
||||
- [Why Rust?](#why-rust)
|
||||
|
||||
## Features
|
||||
|
||||
- Full support for all official MTProto proxy modes:
|
||||
- Classic
|
||||
- Secure - with `dd` prefix
|
||||
- Fake TLS - with `ee` prefix + SNI fronting
|
||||
- Replay attack protection
|
||||
- Optional traffic masking: forward unrecognized connections to a real web server, e.g. GitHub 🤪
|
||||
- Configurable keepalives + timeouts + IPv6 and "Fast Mode"
|
||||
- Graceful shutdown on Ctrl+C
|
||||
- Extensive logging via `trace` and `debug` with `RUST_LOG` method
|
||||
|
||||
## Quick Start Guide
|
||||
**This software is designed for Debian-based OS: in addition to Debian, these are Ubuntu, Mint, Kali, MX and many other Linux**
|
||||
1. Download release
|
||||
```bash
|
||||
wget https://github.com/telemt/telemt/releases/latest/download/telemt
|
||||
```
|
||||
2. Move to Bin Folder
|
||||
```bash
|
||||
mv telemt /bin
|
||||
```
|
||||
4. Make Executable
|
||||
```bash
|
||||
chmod +x /bin/telemt
|
||||
```
|
||||
5. Go to [How to use?](#how-to-use) section for for further steps
|
||||
|
||||
## How to use?
|
||||
### Telemt via Systemd
|
||||
**This instruction "assume" that you:**
|
||||
- logged in as root or executed `su -` / `sudo su`
|
||||
- you already have an assembled and executable `telemt` in /bin folder as a result of the [Quick Start Guide](#quick-start-guide) or [Build](#build)
|
||||
|
||||
**0. Check port and generate secrets**
|
||||
|
||||
The port you have selected for use should be MISSING from the list, when:
|
||||
```bash
|
||||
netstat -lnp
|
||||
```
|
||||
|
||||
Generate 16 bytes/32 characters HEX with OpenSSL or another way:
|
||||
```bash
|
||||
openssl rand -hex 16
|
||||
```
|
||||
OR
|
||||
```bash
|
||||
xxd -l 16 -p /dev/urandom
|
||||
```
|
||||
OR
|
||||
```bash
|
||||
python3 -c 'import os; print(os.urandom(16).hex())'
|
||||
```
|
||||
|
||||
**1. Place your config to /etc/telemt.toml**
|
||||
|
||||
Open nano
|
||||
```bash
|
||||
nano /etc/telemt.toml
|
||||
```
|
||||
paste your config from [Configuration](#configuration) section
|
||||
|
||||
then Ctrl+X -> Y -> Enter to save
|
||||
|
||||
**2. Create service on /etc/systemd/system/telemt.service**
|
||||
|
||||
Open nano
|
||||
```bash
|
||||
nano /etc/systemd/system/telemt.service
|
||||
```
|
||||
paste this Systemd Module
|
||||
```bash
|
||||
[Unit]
|
||||
Description=Telemt
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=/bin
|
||||
ExecStart=/bin/telemt /etc/telemt.toml
|
||||
Restart=on-failure
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
then Ctrl+X -> Y -> Enter to save
|
||||
|
||||
**3.** In Shell type `systemctl start telemt` - it must start with zero exit-code
|
||||
|
||||
**4.** In Shell type `systemctl status telemt` - there you can reach info about current MTProxy status
|
||||
|
||||
**5.** In Shell type `systemctl enable telemt` - then telemt will start with system startup, after the network is up
|
||||
|
||||
## Configuration
|
||||
### Minimal Configuration for First Start
|
||||
```toml
|
||||
port = 443 # Listening port
|
||||
show_links = ["tele", "hello"] # Specify users, for whom will be displayed the links
|
||||
|
||||
[users]
|
||||
tele = "00000000000000000000000000000000" # Replace the secret with one generated before
|
||||
hello = "00000000000000000000000000000000" # Replace the secret with one generated before
|
||||
|
||||
[modes]
|
||||
classic = false # Plain obfuscated mode
|
||||
secure = false # dd-prefix mode
|
||||
tls = true # Fake TLS - ee-prefix
|
||||
|
||||
tls_domain = "petrovich.ru" # Domain for ee-secret and masking
|
||||
mask = true # Enable masking of bad traffic
|
||||
mask_host = "petrovich.ru" # Optional override for mask destination
|
||||
mask_port = 443 # Port for masking
|
||||
|
||||
prefer_ipv6 = false # Try IPv6 DCs first if true
|
||||
fast_mode = true # Use "fast" obfuscation variant
|
||||
|
||||
client_keepalive = 600 # Seconds
|
||||
client_ack_timeout = 300 # Seconds
|
||||
```
|
||||
### Advanced
|
||||
#### Adtag
|
||||
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to the end of config.toml and specify it
|
||||
```toml
|
||||
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
|
||||
```
|
||||
#### Listening and Announce IPs
|
||||
To specify listening address and/or address in links, add to the end of config.toml:
|
||||
```toml
|
||||
[[listeners]]
|
||||
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
|
||||
announce_ip = "1.2.3.4" # IP in links; comment with # if not used
|
||||
```
|
||||
#### Upstream Manager
|
||||
To specify upstream, add to the end of config.toml:
|
||||
##### Bind on IP
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "direct"
|
||||
weight = 1
|
||||
enabled = true
|
||||
interface = "192.168.1.100" # Change to your outgoing IP
|
||||
```
|
||||
##### SOCKS4/5 as Upstream
|
||||
- Without Auth:
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||
weight = 1 # Set Weight for Scenarios
|
||||
enabled = true
|
||||
```
|
||||
|
||||
- With Auth:
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||
username = "user" # Username for Auth on SOCKS-server
|
||||
password = "pass" # Password for Auth on SOCKS-server
|
||||
weight = 1 # Set Weight for Scenarios
|
||||
enabled = true
|
||||
```
|
||||
|
||||
## FAQ
|
||||
### Telegram Calls via MTProxy
|
||||
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
|
||||
### How does DPI see MTProxy TLS?
|
||||
- DPI sees MTProxy in Fake TLS (ee) mode as TLS 1.3
|
||||
- the SNI you specify sends both the client and the server;
|
||||
- ALPN is similar to HTTP 1.1/2;
|
||||
- high entropy, which is normal for AES-encrypted traffic;
|
||||
### Whitelist on IP
|
||||
- MTProxy cannot work when there is:
|
||||
- no IP connectivity to the target host: Russian Whitelist on Mobile Networks - "Белый список"
|
||||
- OR all TCP traffic is blocked
|
||||
- OR high entropy/encrypted traffic is blocked: content filters at universities and critical infrastructure
|
||||
- OR all TLS traffic is blocked
|
||||
- OR specified port is blocked: use 443 to make it "like real"
|
||||
- OR provided SNI is blocked: use "officially approved"/innocuous name
|
||||
- like most protocols on the Internet;
|
||||
- these situations are observed:
|
||||
- in China behind the Great Firewall
|
||||
- in Russia on mobile networks, less in wired networks
|
||||
- in Iran during "activity"
|
||||
|
||||
|
||||
## Build
|
||||
```bash
|
||||
# Cloning repo
|
||||
git clone https://github.com/telemt/telemt
|
||||
# Changing Directory to telemt
|
||||
cd telemt
|
||||
# Starting Release Build
|
||||
cargo build --release
|
||||
# Move to /bin
|
||||
mv ./target/release/telemt /bin
|
||||
# Make executable
|
||||
chmod +x /bin/telemt
|
||||
# Lets go!
|
||||
telemt config.toml
|
||||
```
|
||||
|
||||
## Why Rust?
|
||||
- Long-running reliability and idempotent behavior
|
||||
- Rust’s deterministic resource management - RAII
|
||||
- No garbage collector
|
||||
- Memory safety and reduced attack surface
|
||||
- Tokio's asynchronous architecture
|
||||
|
||||
## Issues
|
||||
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
|
||||
- ⌛ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
|
||||
|
||||
## Roadmap
|
||||
- Public IP in links
|
||||
- Config Reload-on-fly
|
||||
- Bind to device or IP for outbound/inbound connections
|
||||
- Adtag Support per SNI / Secret
|
||||
- Fail-fast on start + Fail-soft on runtime (only WARN/ERROR)
|
||||
- Zero-copy, minimal allocs on hotpath
|
||||
- DC Healthchecks + global fallback
|
||||
- No global mutable state
|
||||
- Client isolation + Fair Bandwidth
|
||||
- Backpressure-aware IO
|
||||
- "Secret Policy" - SNI / Secret Routing :D
|
||||
- Multi-upstream Balancer and Failover
|
||||
- Strict FSM per handshake
|
||||
- Session-based Antireplay with Sliding window, non-broking reconnects
|
||||
- Web Control: statistic, state of health, latency, client experience...
|
||||
|
||||
@@ -18,6 +18,7 @@ pub struct ProxyModes {
|
||||
}
|
||||
|
||||
fn default_true() -> bool { true }
|
||||
fn default_weight() -> u16 { 1 }
|
||||
|
||||
impl Default for ProxyModes {
|
||||
fn default() -> Self {
|
||||
@@ -25,6 +26,48 @@ impl Default for ProxyModes {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum UpstreamType {
|
||||
Direct {
|
||||
#[serde(default)]
|
||||
interface: Option<String>, // Bind to specific IP/Interface
|
||||
},
|
||||
Socks4 {
|
||||
address: String, // IP:Port of SOCKS server
|
||||
#[serde(default)]
|
||||
interface: Option<String>, // Bind to specific IP/Interface for connection to SOCKS
|
||||
#[serde(default)]
|
||||
user_id: Option<String>,
|
||||
},
|
||||
Socks5 {
|
||||
address: String,
|
||||
#[serde(default)]
|
||||
interface: Option<String>,
|
||||
#[serde(default)]
|
||||
username: Option<String>,
|
||||
#[serde(default)]
|
||||
password: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpstreamConfig {
|
||||
#[serde(flatten)]
|
||||
pub upstream_type: UpstreamType,
|
||||
#[serde(default = "default_weight")]
|
||||
pub weight: u16,
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListenerConfig {
|
||||
pub ip: IpAddr,
|
||||
#[serde(default)]
|
||||
pub announce_ip: Option<IpAddr>, // IP to show in tg:// links
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyConfig {
|
||||
#[serde(default = "default_port")]
|
||||
@@ -104,15 +147,28 @@ pub struct ProxyConfig {
|
||||
|
||||
#[serde(default = "default_fake_cert_len")]
|
||||
pub fake_cert_len: usize,
|
||||
|
||||
// New fields
|
||||
#[serde(default)]
|
||||
pub upstreams: Vec<UpstreamConfig>,
|
||||
|
||||
#[serde(default)]
|
||||
pub listeners: Vec<ListenerConfig>,
|
||||
|
||||
#[serde(default)]
|
||||
pub show_link: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_port() -> u16 { 443 }
|
||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||
fn default_mask_port() -> u16 { 443 }
|
||||
fn default_replay_check_len() -> usize { 65536 }
|
||||
fn default_handshake_timeout() -> u64 { 10 }
|
||||
// CHANGED: Increased handshake timeout for bad mobile networks
|
||||
fn default_handshake_timeout() -> u64 { 15 }
|
||||
fn default_connect_timeout() -> u64 { 10 }
|
||||
fn default_keepalive() -> u64 { 600 }
|
||||
// CHANGED: Reduced keepalive from 600s to 60s.
|
||||
// Mobile NATs often drop idle connections after 60-120s.
|
||||
fn default_keepalive() -> u64 { 60 }
|
||||
fn default_ack_timeout() -> u64 { 300 }
|
||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
||||
fn default_fake_cert_len() -> usize { 2048 }
|
||||
@@ -156,6 +212,9 @@ impl Default for ProxyConfig {
|
||||
metrics_port: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
fake_cert_len: default_fake_cert_len(),
|
||||
upstreams: Vec::new(),
|
||||
listeners: Vec::new(),
|
||||
show_link: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,6 +246,33 @@ impl ProxyConfig {
|
||||
use rand::Rng;
|
||||
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
|
||||
|
||||
// Migration: Populate listeners if empty
|
||||
if config.listeners.is_empty() {
|
||||
if let Ok(ipv4) = config.listen_addr_ipv4.parse::<IpAddr>() {
|
||||
config.listeners.push(ListenerConfig {
|
||||
ip: ipv4,
|
||||
announce_ip: None,
|
||||
});
|
||||
}
|
||||
if let Some(ipv6_str) = &config.listen_addr_ipv6 {
|
||||
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
|
||||
config.listeners.push(ListenerConfig {
|
||||
ip: ipv6,
|
||||
announce_ip: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Migration: Populate upstreams if empty (Default Direct)
|
||||
if config.upstreams.is_empty() {
|
||||
config.upstreams.push(UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct { interface: None },
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -202,26 +288,3 @@ impl ProxyConfig {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ProxyConfig::default();
|
||||
assert_eq!(config.port, 443);
|
||||
assert!(config.modes.tls);
|
||||
assert_eq!(config.client_keepalive, 600);
|
||||
assert_eq!(config.client_ack_timeout, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate() {
|
||||
let mut config = ProxyConfig::default();
|
||||
assert!(config.validate().is_ok());
|
||||
|
||||
config.users.clear();
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,24 @@
|
||||
//! AES
|
||||
//! AES encryption implementations
|
||||
//!
|
||||
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
|
||||
|
||||
use aes::Aes256;
|
||||
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
||||
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
|
||||
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
type Aes256Ctr = Ctr128BE<Aes256>;
|
||||
type Aes256CbcEnc = CbcEncryptor<Aes256>;
|
||||
type Aes256CbcDec = CbcDecryptor<Aes256>;
|
||||
|
||||
// ============= AES-256-CTR =============
|
||||
|
||||
/// AES-256-CTR encryptor/decryptor
|
||||
///
|
||||
/// CTR mode is symmetric - encryption and decryption are the same operation.
|
||||
pub struct AesCtr {
|
||||
cipher: Aes256Ctr,
|
||||
}
|
||||
|
||||
impl AesCtr {
|
||||
/// Create new AES-CTR cipher with key and IV
|
||||
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
||||
let iv_bytes = iv.to_be_bytes();
|
||||
Self {
|
||||
@@ -23,6 +26,7 @@ impl AesCtr {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from key and IV slices
|
||||
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||
if key.len() != 32 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||
@@ -54,17 +58,28 @@ impl AesCtr {
|
||||
}
|
||||
}
|
||||
|
||||
/// AES-256-CBC Ciphermagic
|
||||
// ============= AES-256-CBC =============
|
||||
|
||||
/// AES-256-CBC cipher with proper chaining
|
||||
///
|
||||
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
|
||||
/// are different operations. This implementation handles CBC chaining
|
||||
/// correctly across multiple blocks.
|
||||
pub struct AesCbc {
|
||||
key: [u8; 32],
|
||||
iv: [u8; 16],
|
||||
}
|
||||
|
||||
impl AesCbc {
|
||||
/// AES block size
|
||||
const BLOCK_SIZE: usize = 16;
|
||||
|
||||
/// Create new AES-CBC cipher with key and IV
|
||||
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
||||
Self { key, iv }
|
||||
}
|
||||
|
||||
/// Create from slices
|
||||
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||
if key.len() != 32 {
|
||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||
@@ -79,32 +94,36 @@ impl AesCbc {
|
||||
})
|
||||
}
|
||||
|
||||
/// Encrypt data using CBC mode
|
||||
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % 16 != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut buffer = data.to_vec();
|
||||
|
||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
||||
|
||||
for chunk in buffer.chunks_mut(16) {
|
||||
encryptor.encrypt_block_mut(chunk.into());
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
/// Encrypt a single block using raw AES (no chaining)
|
||||
fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||
use aes::cipher::BlockEncrypt;
|
||||
let mut output = *block;
|
||||
key_schedule.encrypt_block((&mut output).into());
|
||||
output
|
||||
}
|
||||
|
||||
/// Decrypt data using CBC mode
|
||||
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % 16 != 0 {
|
||||
/// Decrypt a single block using raw AES (no chaining)
|
||||
fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||
use aes::cipher::BlockDecrypt;
|
||||
let mut output = *block;
|
||||
key_schedule.decrypt_block((&mut output).into());
|
||||
output
|
||||
}
|
||||
|
||||
/// XOR two 16-byte blocks
|
||||
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
|
||||
let mut result = [0u8; 16];
|
||||
for i in 0..16 {
|
||||
result[i] = a[i] ^ b[i];
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Encrypt data using CBC mode with proper chaining
|
||||
///
|
||||
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
|
||||
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
@@ -114,20 +133,73 @@ impl AesCbc {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut buffer = data.to_vec();
|
||||
use aes::cipher::KeyInit;
|
||||
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||
|
||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
||||
let mut result = Vec::with_capacity(data.len());
|
||||
let mut prev_ciphertext = self.iv;
|
||||
|
||||
for chunk in buffer.chunks_mut(16) {
|
||||
decryptor.decrypt_block_mut(chunk.into());
|
||||
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||
let plaintext: [u8; 16] = chunk.try_into().unwrap();
|
||||
|
||||
// XOR plaintext with previous ciphertext (or IV for first block)
|
||||
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
|
||||
|
||||
// Encrypt the XORed block
|
||||
let ciphertext = self.encrypt_block(&xored, &key_schedule);
|
||||
|
||||
// Save for next iteration
|
||||
prev_ciphertext = ciphertext;
|
||||
|
||||
// Append to result
|
||||
result.extend_from_slice(&ciphertext);
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Decrypt data using CBC mode with proper chaining
|
||||
///
|
||||
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
|
||||
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
use aes::cipher::KeyInit;
|
||||
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||
|
||||
let mut result = Vec::with_capacity(data.len());
|
||||
let mut prev_ciphertext = self.iv;
|
||||
|
||||
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
|
||||
|
||||
// Decrypt the block
|
||||
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
|
||||
|
||||
// XOR with previous ciphertext (or IV for first block)
|
||||
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
|
||||
|
||||
// Save current ciphertext for next iteration
|
||||
prev_ciphertext = ciphertext;
|
||||
|
||||
// Append to result
|
||||
result.extend_from_slice(&plaintext);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Encrypt data in-place
|
||||
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||
if data.len() % 16 != 0 {
|
||||
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
@@ -137,10 +209,25 @@ impl AesCbc {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
||||
use aes::cipher::KeyInit;
|
||||
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||
|
||||
for chunk in data.chunks_mut(16) {
|
||||
encryptor.encrypt_block_mut(chunk.into());
|
||||
let mut prev_ciphertext = self.iv;
|
||||
|
||||
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||
|
||||
// XOR with previous ciphertext
|
||||
for j in 0..Self::BLOCK_SIZE {
|
||||
block[j] ^= prev_ciphertext[j];
|
||||
}
|
||||
|
||||
// Encrypt in-place
|
||||
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||
*block_array = self.encrypt_block(block_array, &key_schedule);
|
||||
|
||||
// Save for next iteration
|
||||
prev_ciphertext = *block_array;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -148,7 +235,7 @@ impl AesCbc {
|
||||
|
||||
/// Decrypt data in-place
|
||||
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||
if data.len() % 16 != 0 {
|
||||
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||
return Err(ProxyError::Crypto(
|
||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||
));
|
||||
@@ -158,16 +245,38 @@ impl AesCbc {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
||||
use aes::cipher::KeyInit;
|
||||
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||
|
||||
for chunk in data.chunks_mut(16) {
|
||||
decryptor.decrypt_block_mut(chunk.into());
|
||||
// For in-place decryption, we need to save ciphertext blocks
|
||||
// before we overwrite them
|
||||
let mut prev_ciphertext = self.iv;
|
||||
|
||||
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||
|
||||
// Save current ciphertext before modifying
|
||||
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
|
||||
|
||||
// Decrypt in-place
|
||||
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||
*block_array = self.decrypt_block(block_array, &key_schedule);
|
||||
|
||||
// XOR with previous ciphertext
|
||||
for j in 0..Self::BLOCK_SIZE {
|
||||
block[j] ^= prev_ciphertext[j];
|
||||
}
|
||||
|
||||
// Save for next iteration
|
||||
prev_ciphertext = current_ciphertext;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Encryption Traits =============
|
||||
|
||||
/// Trait for unified encryption interface
|
||||
pub trait Encryptor: Send + Sync {
|
||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
||||
@@ -209,6 +318,8 @@ impl Decryptor for PassthroughEncryptor {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ============= AES-CTR Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_aes_ctr_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
@@ -225,13 +336,35 @@ mod tests {
|
||||
assert_eq!(original.as_slice(), decrypted.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_ctr_in_place() {
|
||||
let key = [0x42u8; 32];
|
||||
let iv = 999u128;
|
||||
|
||||
let original = b"Test data for in-place encryption";
|
||||
let mut data = original.to_vec();
|
||||
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.apply(&mut data);
|
||||
|
||||
// Encrypted should be different
|
||||
assert_ne!(&data[..], original);
|
||||
|
||||
// Decrypt with fresh cipher
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.apply(&mut data);
|
||||
|
||||
assert_eq!(&data[..], original);
|
||||
}
|
||||
|
||||
// ============= AES-CBC Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
let iv = [0u8; 16];
|
||||
|
||||
// Must be aligned to 16 bytes
|
||||
let original = [0u8; 32];
|
||||
let original = [0u8; 32]; // 2 blocks
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let encrypted = cipher.encrypt(&original).unwrap();
|
||||
@@ -242,47 +375,59 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_chaining_works() {
|
||||
// This is the key test - verify CBC chaining is correct
|
||||
let key = [0x42u8; 32];
|
||||
let iv = [0x00u8; 16];
|
||||
|
||||
let plaintext = [0xAA_u8; 32];
|
||||
// Two IDENTICAL plaintext blocks
|
||||
let plaintext = [0xAAu8; 32];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||
|
||||
// CBC Corrections
|
||||
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
|
||||
let block1 = &ciphertext[0..16];
|
||||
let block2 = &ciphertext[16..32];
|
||||
|
||||
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
|
||||
assert_ne!(
|
||||
block1, block2,
|
||||
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \
|
||||
This indicates ECB mode, not CBC!"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_known_vector() {
|
||||
// Test with known NIST test vector
|
||||
// AES-256-CBC with zero key and zero IV
|
||||
let key = [0u8; 32];
|
||||
let iv = [0u8; 16];
|
||||
|
||||
// 3 Datablocks
|
||||
let plaintext = [
|
||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
||||
// Block 2
|
||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
||||
// Block 3 - different
|
||||
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
|
||||
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
|
||||
];
|
||||
let plaintext = [0u8; 16];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||
|
||||
// Decrypt + Verify
|
||||
// Decrypt and verify roundtrip
|
||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
||||
|
||||
// Verify Ciphertexts Block 1 != Block 2
|
||||
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
|
||||
// Ciphertext should not be all zeros
|
||||
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_multi_block() {
|
||||
let key = [0x12u8; 32];
|
||||
let iv = [0x34u8; 16];
|
||||
|
||||
// 5 blocks = 80 bytes
|
||||
let plaintext: Vec<u8> = (0..80).collect();
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||
|
||||
assert_eq!(plaintext, decrypted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -291,7 +436,7 @@ mod tests {
|
||||
let iv = [0x34u8; 16];
|
||||
|
||||
let original = [0x56u8; 48]; // 3 blocks
|
||||
let mut buffer = original.clone();
|
||||
let mut buffer = original;
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
|
||||
@@ -317,35 +462,85 @@ mod tests {
|
||||
fn test_aes_cbc_unaligned_error() {
|
||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||
|
||||
// 15 bytes
|
||||
// 15 bytes - not aligned to block size
|
||||
let result = cipher.encrypt(&[0u8; 15]);
|
||||
assert!(result.is_err());
|
||||
|
||||
// 17 bytes
|
||||
// 17 bytes - not aligned
|
||||
let result = cipher.encrypt(&[0u8; 17]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_avalanche_effect() {
|
||||
// Cipherplane
|
||||
|
||||
// Changing one bit in plaintext should change entire ciphertext block
|
||||
// and all subsequent blocks (due to chaining)
|
||||
let key = [0xAB; 32];
|
||||
let iv = [0xCD; 16];
|
||||
|
||||
let mut plaintext1 = [0u8; 32];
|
||||
let mut plaintext2 = [0u8; 32];
|
||||
plaintext2[0] = 0x01; // Один бит отличается
|
||||
plaintext2[0] = 0x01; // Single bit difference in first block
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
|
||||
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
||||
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
||||
|
||||
// First Blocks Diff
|
||||
// First blocks should be different
|
||||
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
||||
|
||||
// Second Blocks Diff
|
||||
// Second blocks should ALSO be different (chaining effect)
|
||||
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_iv_matters() {
|
||||
// Same plaintext with different IVs should produce different ciphertext
|
||||
let key = [0x55; 32];
|
||||
let plaintext = [0x77u8; 16];
|
||||
|
||||
let cipher1 = AesCbc::new(key, [0u8; 16]);
|
||||
let cipher2 = AesCbc::new(key, [1u8; 16]);
|
||||
|
||||
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
|
||||
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
|
||||
|
||||
assert_ne!(ciphertext1, ciphertext2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aes_cbc_deterministic() {
|
||||
// Same key, IV, plaintext should always produce same ciphertext
|
||||
let key = [0x99; 32];
|
||||
let iv = [0x88; 16];
|
||||
let plaintext = [0x77u8; 32];
|
||||
|
||||
let cipher = AesCbc::new(key, iv);
|
||||
|
||||
let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
|
||||
let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
|
||||
|
||||
assert_eq!(ciphertext1, ciphertext2);
|
||||
}
|
||||
|
||||
// ============= Error Handling Tests =============
|
||||
|
||||
#[test]
|
||||
fn test_invalid_key_length() {
|
||||
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
|
||||
assert!(result.is_err());
|
||||
|
||||
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_iv_length() {
|
||||
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
|
||||
assert!(result.is_err());
|
||||
|
||||
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
264
src/error.rs
264
src/error.rs
@@ -1,8 +1,177 @@
|
||||
//! Error Types
|
||||
|
||||
use std::fmt;
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
|
||||
// ============= Stream Errors =============
|
||||
|
||||
/// Errors specific to stream I/O operations
|
||||
#[derive(Debug)]
|
||||
pub enum StreamError {
|
||||
/// Partial read: got fewer bytes than expected
|
||||
PartialRead {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
/// Partial write: wrote fewer bytes than expected
|
||||
PartialWrite {
|
||||
expected: usize,
|
||||
written: usize,
|
||||
},
|
||||
/// Stream is in poisoned state and cannot be used
|
||||
Poisoned {
|
||||
reason: String,
|
||||
},
|
||||
/// Buffer overflow: attempted to buffer more than allowed
|
||||
BufferOverflow {
|
||||
limit: usize,
|
||||
attempted: usize,
|
||||
},
|
||||
/// Invalid frame format
|
||||
InvalidFrame {
|
||||
details: String,
|
||||
},
|
||||
/// Unexpected end of stream
|
||||
UnexpectedEof,
|
||||
/// Underlying I/O error
|
||||
Io(std::io::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for StreamError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::PartialRead { expected, got } => {
|
||||
write!(f, "partial read: expected {} bytes, got {}", expected, got)
|
||||
}
|
||||
Self::PartialWrite { expected, written } => {
|
||||
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
|
||||
}
|
||||
Self::Poisoned { reason } => {
|
||||
write!(f, "stream poisoned: {}", reason)
|
||||
}
|
||||
Self::BufferOverflow { limit, attempted } => {
|
||||
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
|
||||
}
|
||||
Self::InvalidFrame { details } => {
|
||||
write!(f, "invalid frame: {}", details)
|
||||
}
|
||||
Self::UnexpectedEof => {
|
||||
write!(f, "unexpected end of stream")
|
||||
}
|
||||
Self::Io(e) => {
|
||||
write!(f, "I/O error: {}", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for StreamError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Io(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for StreamError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
Self::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StreamError> for std::io::Error {
|
||||
fn from(err: StreamError) -> Self {
|
||||
match err {
|
||||
StreamError::Io(e) => e,
|
||||
StreamError::UnexpectedEof => {
|
||||
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
|
||||
}
|
||||
StreamError::Poisoned { .. } => {
|
||||
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||
}
|
||||
StreamError::BufferOverflow { .. } => {
|
||||
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
|
||||
}
|
||||
StreamError::InvalidFrame { .. } => {
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
|
||||
}
|
||||
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
|
||||
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Recoverable Trait =============
|
||||
|
||||
/// Trait for errors that may be recoverable
|
||||
pub trait Recoverable {
|
||||
/// Check if error is recoverable (can retry operation)
|
||||
fn is_recoverable(&self) -> bool;
|
||||
|
||||
/// Check if connection can continue after this error
|
||||
fn can_continue(&self) -> bool;
|
||||
}
|
||||
|
||||
impl Recoverable for StreamError {
|
||||
fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
// Partial operations can be retried
|
||||
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
||||
// I/O errors depend on kind
|
||||
Self::Io(e) => matches!(
|
||||
e.kind(),
|
||||
std::io::ErrorKind::WouldBlock
|
||||
| std::io::ErrorKind::Interrupted
|
||||
| std::io::ErrorKind::TimedOut
|
||||
),
|
||||
// These are not recoverable
|
||||
Self::Poisoned { .. }
|
||||
| Self::BufferOverflow { .. }
|
||||
| Self::InvalidFrame { .. }
|
||||
| Self::UnexpectedEof => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn can_continue(&self) -> bool {
|
||||
match self {
|
||||
// Poisoned stream cannot be used
|
||||
Self::Poisoned { .. } => false,
|
||||
// EOF means stream is done
|
||||
Self::UnexpectedEof => false,
|
||||
// Buffer overflow is fatal
|
||||
Self::BufferOverflow { .. } => false,
|
||||
// Others might allow continuation
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Recoverable for std::io::Error {
|
||||
fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self.kind(),
|
||||
std::io::ErrorKind::WouldBlock
|
||||
| std::io::ErrorKind::Interrupted
|
||||
| std::io::ErrorKind::TimedOut
|
||||
)
|
||||
}
|
||||
|
||||
fn can_continue(&self) -> bool {
|
||||
!matches!(
|
||||
self.kind(),
|
||||
std::io::ErrorKind::BrokenPipe
|
||||
| std::io::ErrorKind::ConnectionReset
|
||||
| std::io::ErrorKind::ConnectionAborted
|
||||
| std::io::ErrorKind::NotConnected
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Main Proxy Errors =============
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ProxyError {
|
||||
// ============= Crypto Errors =============
|
||||
@@ -13,6 +182,11 @@ pub enum ProxyError {
|
||||
#[error("Invalid key length: expected {expected}, got {got}")]
|
||||
InvalidKeyLength { expected: usize, got: usize },
|
||||
|
||||
// ============= Stream Errors =============
|
||||
|
||||
#[error("Stream error: {0}")]
|
||||
Stream(#[from] StreamError),
|
||||
|
||||
// ============= Protocol Errors =============
|
||||
|
||||
#[error("Invalid handshake: {0}")]
|
||||
@@ -39,6 +213,12 @@ pub enum ProxyError {
|
||||
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
||||
SeqNoMismatch { expected: i32, got: i32 },
|
||||
|
||||
#[error("TLS handshake failed: {reason}")]
|
||||
TlsHandshakeFailed { reason: String },
|
||||
|
||||
#[error("Telegram handshake timeout")]
|
||||
TgHandshakeTimeout,
|
||||
|
||||
// ============= Network Errors =============
|
||||
|
||||
#[error("Connection timeout to {addr}")]
|
||||
@@ -55,6 +235,9 @@ pub enum ProxyError {
|
||||
#[error("Invalid proxy protocol header")]
|
||||
InvalidProxyProtocol,
|
||||
|
||||
#[error("Proxy error: {0}")]
|
||||
Proxy(String),
|
||||
|
||||
// ============= Config Errors =============
|
||||
|
||||
#[error("Config error: {0}")]
|
||||
@@ -77,15 +260,41 @@ pub enum ProxyError {
|
||||
#[error("Unknown user")]
|
||||
UnknownUser,
|
||||
|
||||
#[error("Rate limited")]
|
||||
RateLimited,
|
||||
|
||||
// ============= General Errors =============
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl Recoverable for ProxyError {
|
||||
fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::Stream(e) => e.is_recoverable(),
|
||||
Self::Io(e) => e.is_recoverable(),
|
||||
Self::ConnectionTimeout { .. } => true,
|
||||
Self::RateLimited => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn can_continue(&self) -> bool {
|
||||
match self {
|
||||
Self::Stream(e) => e.can_continue(),
|
||||
Self::Io(e) => e.can_continue(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenient Result type alias
|
||||
pub type Result<T> = std::result::Result<T, ProxyError>;
|
||||
|
||||
/// Result type for stream operations
|
||||
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
||||
|
||||
/// Result with optional bad client handling
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeResult<T> {
|
||||
@@ -125,6 +334,14 @@ impl<T> HandshakeResult<T> {
|
||||
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert success to Option
|
||||
pub fn ok(self) -> Option<T> {
|
||||
match self {
|
||||
HandshakeResult::Success(v) => Some(v),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<ProxyError> for HandshakeResult<T> {
|
||||
@@ -139,10 +356,48 @@ impl<T> From<std::io::Error> for HandshakeResult<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<StreamError> for HandshakeResult<T> {
|
||||
fn from(err: StreamError) -> Self {
|
||||
HandshakeResult::Error(ProxyError::Stream(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stream_error_display() {
|
||||
let err = StreamError::PartialRead { expected: 100, got: 50 };
|
||||
assert!(err.to_string().contains("100"));
|
||||
assert!(err.to_string().contains("50"));
|
||||
|
||||
let err = StreamError::Poisoned { reason: "test".into() };
|
||||
assert!(err.to_string().contains("test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_error_recoverable() {
|
||||
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
|
||||
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
|
||||
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
|
||||
assert!(!StreamError::UnexpectedEof.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_error_can_continue() {
|
||||
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
|
||||
assert!(!StreamError::UnexpectedEof.can_continue());
|
||||
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_error_to_io_error() {
|
||||
let stream_err = StreamError::UnexpectedEof;
|
||||
let io_err: std::io::Error = stream_err.into();
|
||||
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handshake_result() {
|
||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
||||
@@ -165,6 +420,15 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_error_recoverable() {
|
||||
let err = ProxyError::RateLimited;
|
||||
assert!(err.is_recoverable());
|
||||
|
||||
let err = ProxyError::InvalidHandshake("bad".into());
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
||||
|
||||
304
src/main.rs
304
src/main.rs
@@ -1,158 +1,196 @@
|
||||
//! Telemt - MTProxy on Rust
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::{info, error, Level};
|
||||
use tracing_subscriber::{FmtSubscriber, EnvFilter};
|
||||
use tracing::{info, error, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
mod error;
|
||||
mod config;
|
||||
mod crypto;
|
||||
mod error;
|
||||
mod protocol;
|
||||
mod proxy;
|
||||
mod stats;
|
||||
mod stream;
|
||||
mod transport;
|
||||
mod proxy;
|
||||
mod config;
|
||||
mod stats;
|
||||
mod util;
|
||||
|
||||
use config::ProxyConfig;
|
||||
use stats::{Stats, ReplayChecker};
|
||||
use transport::ConnectionPool;
|
||||
use proxy::ClientHandler;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::proxy::ClientHandler;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||
use crate::util::ip::detect_ip;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging with env filter
|
||||
// Use RUST_LOG=debug or RUST_LOG=trace for more details
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging
|
||||
fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
|
||||
.init();
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(false)
|
||||
.with_file(false)
|
||||
.with_line_number(false)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber)?;
|
||||
|
||||
// Load configuration
|
||||
let config_path = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "config.toml".to_string());
|
||||
|
||||
info!("Loading configuration from {}", config_path);
|
||||
|
||||
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
|
||||
error!("Failed to load config: {}", e);
|
||||
info!("Using default configuration");
|
||||
ProxyConfig::default()
|
||||
});
|
||||
|
||||
if let Err(e) = config.validate() {
|
||||
error!("Invalid configuration: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
info!("Starting MTProto Proxy on port {}", config.port);
|
||||
info!("Fast mode: {}", config.fast_mode);
|
||||
info!("Modes: classic={}, secure={}, tls={}",
|
||||
config.modes.classic, config.modes.secure, config.modes.tls);
|
||||
|
||||
// Initialize components
|
||||
let stats = Arc::new(Stats::new());
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||
let pool = Arc::new(ConnectionPool::new());
|
||||
|
||||
// Create handler
|
||||
let handler = Arc::new(ClientHandler::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&stats),
|
||||
Arc::clone(&replay_checker),
|
||||
Arc::clone(&pool),
|
||||
));
|
||||
|
||||
// Start listener
|
||||
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
|
||||
.parse()?;
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Print proxy links
|
||||
print_proxy_links(&config);
|
||||
|
||||
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
|
||||
|
||||
// Main accept loop
|
||||
let accept_loop = async {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer)) => {
|
||||
let handler = Arc::clone(&handler);
|
||||
tokio::spawn(async move {
|
||||
handler.handle(stream, peer).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
}
|
||||
// Load config
|
||||
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
|
||||
let config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
// If config doesn't exist, try to create default
|
||||
if std::path::Path::new(&config_path).exists() {
|
||||
error!("Failed to load config: {}", e);
|
||||
std::process::exit(1);
|
||||
} else {
|
||||
let default = ProxyConfig::default();
|
||||
let toml = toml::to_string_pretty(&default).unwrap();
|
||||
std::fs::write(&config_path, toml).unwrap();
|
||||
info!("Created default config at {}", config_path);
|
||||
default
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Graceful shutdown
|
||||
tokio::select! {
|
||||
_ = accept_loop => {}
|
||||
_ = signal::ctrl_c() => {
|
||||
info!("Shutting down...");
|
||||
config.validate()?;
|
||||
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
// CHANGED: Initialize global ReplayChecker here instead of per-connection
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||
|
||||
// Initialize Upstream Manager
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||
|
||||
// Start Health Checks
|
||||
let um_clone = upstream_manager.clone();
|
||||
tokio::spawn(async move {
|
||||
um_clone.run_health_checks().await;
|
||||
});
|
||||
|
||||
// Detect public IP if needed (once at startup)
|
||||
let detected_ip = detect_ip().await;
|
||||
|
||||
// Start Listeners
|
||||
let mut listeners = Vec::new();
|
||||
|
||||
for listener_conf in &config.listeners {
|
||||
let addr = SocketAddr::new(listener_conf.ip, config.port);
|
||||
let options = ListenOptions {
|
||||
ipv6_only: listener_conf.ip.is_ipv6(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
match create_listener(addr, &options) {
|
||||
Ok(socket) => {
|
||||
let listener = TcpListener::from_std(socket.into())?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Determine public IP for tg:// links
|
||||
// 1. Use explicit announce_ip if set
|
||||
// 2. If listening on 0.0.0.0 or ::, use detected public IP
|
||||
// 3. Otherwise use the bind IP
|
||||
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
||||
ip
|
||||
} else if listener_conf.ip.is_unspecified() {
|
||||
// Try to use detected IP of the same family
|
||||
if listener_conf.ip.is_ipv4() {
|
||||
detected_ip.ipv4.unwrap_or(listener_conf.ip)
|
||||
} else {
|
||||
detected_ip.ipv6.unwrap_or(listener_conf.ip)
|
||||
}
|
||||
} else {
|
||||
listener_conf.ip
|
||||
};
|
||||
|
||||
// Show links for configured users
|
||||
if !config.show_link.is_empty() {
|
||||
info!("--- Proxy Links for {} ---", public_ip);
|
||||
for user_name in &config.show_link {
|
||||
if let Some(secret) = config.users.get(user_name) {
|
||||
info!("User: {}", user_name);
|
||||
|
||||
// Classic
|
||||
if config.modes.classic {
|
||||
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
|
||||
public_ip, config.port, secret);
|
||||
}
|
||||
|
||||
// DD (Secure)
|
||||
if config.modes.secure {
|
||||
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
||||
public_ip, config.port, secret);
|
||||
}
|
||||
|
||||
// EE-TLS (FakeTLS)
|
||||
if config.modes.tls {
|
||||
let domain_hex = hex::encode(&config.tls_domain);
|
||||
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
public_ip, config.port, secret, domain_hex);
|
||||
}
|
||||
} else {
|
||||
warn!("User '{}' specified in show_link not found in users list", user_name);
|
||||
}
|
||||
}
|
||||
info!("-----------------------------------");
|
||||
}
|
||||
|
||||
listeners.push(listener);
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to bind to {}: {}", addr, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
pool.close_all().await;
|
||||
if listeners.is_empty() {
|
||||
error!("No listeners could be started. Exiting.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Accept loop
|
||||
for listener in listeners {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer_addr)) => {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = ClientHandler::new(
|
||||
stream,
|
||||
peer_addr,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker // Pass global checker
|
||||
).run().await {
|
||||
// Log only relevant errors
|
||||
// debug!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for signal
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => info!("Shutting down..."),
|
||||
Err(e) => error!("Signal error: {}", e),
|
||||
}
|
||||
|
||||
info!("Goodbye!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_proxy_links(config: &ProxyConfig) {
|
||||
println!("\n=== Proxy Links ===\n");
|
||||
|
||||
for (user, secret) in &config.users {
|
||||
if config.modes.tls {
|
||||
let tls_secret = format!(
|
||||
"ee{}{}",
|
||||
secret,
|
||||
hex::encode(config.tls_domain.as_bytes())
|
||||
);
|
||||
println!(
|
||||
"{} (TLS): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, tls_secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.secure {
|
||||
println!(
|
||||
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.classic {
|
||||
println!(
|
||||
"{} (Classic): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("===================\n");
|
||||
}
|
||||
@@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
||||
// ============= Buffer Sizes =============
|
||||
|
||||
/// Default buffer size
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
|
||||
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
|
||||
/// the new buffering strategy for better iOS upload performance.
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
||||
|
||||
/// Small buffer size for bad client handling
|
||||
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
//! Fake TLS 1.3 Handshake
|
||||
//!
|
||||
//! This module handles the fake TLS 1.3 handshake used by MTProto proxy
|
||||
//! for domain fronting. The handshake looks like valid TLS 1.3 but
|
||||
//! actually carries MTProto authentication data.
|
||||
|
||||
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
|
||||
use crate::error::{ProxyError, Result};
|
||||
use super::constants::*;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
// ============= Public Constants =============
|
||||
|
||||
/// TLS handshake digest length
|
||||
pub const TLS_DIGEST_LEN: usize = 32;
|
||||
|
||||
/// Position of digest in TLS ClientHello
|
||||
pub const TLS_DIGEST_POS: usize = 11;
|
||||
|
||||
/// Length to store for replay protection (first 16 bytes of digest)
|
||||
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||
|
||||
@@ -16,6 +24,26 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||
|
||||
// ============= Private Constants =============
|
||||
|
||||
/// TLS Extension types
|
||||
mod extension_type {
|
||||
pub const KEY_SHARE: u16 = 0x0033;
|
||||
pub const SUPPORTED_VERSIONS: u16 = 0x002b;
|
||||
}
|
||||
|
||||
/// TLS Cipher Suites
|
||||
mod cipher_suite {
|
||||
pub const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
|
||||
}
|
||||
|
||||
/// TLS Named Curves
|
||||
mod named_curve {
|
||||
pub const X25519: u16 = 0x001d;
|
||||
}
|
||||
|
||||
// ============= TLS Validation Result =============
|
||||
|
||||
/// Result of validating TLS handshake
|
||||
#[derive(Debug)]
|
||||
pub struct TlsValidation {
|
||||
@@ -29,7 +57,185 @@ pub struct TlsValidation {
|
||||
pub timestamp: u32,
|
||||
}
|
||||
|
||||
// ============= TLS Extension Builder =============
|
||||
|
||||
/// Builder for TLS extensions with correct length calculation
|
||||
struct TlsExtensionBuilder {
|
||||
extensions: Vec<u8>,
|
||||
}
|
||||
|
||||
impl TlsExtensionBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
extensions: Vec::with_capacity(128),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add Key Share extension with X25519 key
|
||||
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
|
||||
// Extension type: key_share (0x0033)
|
||||
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
|
||||
|
||||
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
|
||||
// Extension data length
|
||||
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
|
||||
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
|
||||
|
||||
// Named curve: x25519
|
||||
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
|
||||
|
||||
// Key length
|
||||
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
|
||||
|
||||
// Key data
|
||||
self.extensions.extend_from_slice(public_key);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Add Supported Versions extension
|
||||
fn add_supported_versions(&mut self, version: u16) -> &mut Self {
|
||||
// Extension type: supported_versions (0x002b)
|
||||
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
|
||||
|
||||
// Extension data: length (2) + version (2)
|
||||
self.extensions.extend_from_slice(&(2u16).to_be_bytes());
|
||||
|
||||
// Selected version
|
||||
self.extensions.extend_from_slice(&version.to_be_bytes());
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Build final extensions with length prefix
|
||||
fn build(self) -> Vec<u8> {
|
||||
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
||||
|
||||
// Extensions length (2 bytes)
|
||||
let len = self.extensions.len() as u16;
|
||||
result.extend_from_slice(&len.to_be_bytes());
|
||||
|
||||
// Extensions data
|
||||
result.extend_from_slice(&self.extensions);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get current extensions without length prefix (for calculation)
|
||||
#[allow(dead_code)]
|
||||
fn as_bytes(&self) -> &[u8] {
|
||||
&self.extensions
|
||||
}
|
||||
}
|
||||
|
||||
// ============= ServerHello Builder =============
|
||||
|
||||
/// Builder for TLS ServerHello with correct structure
|
||||
struct ServerHelloBuilder {
|
||||
/// Random bytes (32 bytes, will contain digest)
|
||||
random: [u8; 32],
|
||||
/// Session ID (echoed from ClientHello)
|
||||
session_id: Vec<u8>,
|
||||
/// Cipher suite
|
||||
cipher_suite: [u8; 2],
|
||||
/// Compression method
|
||||
compression: u8,
|
||||
/// Extensions
|
||||
extensions: TlsExtensionBuilder,
|
||||
}
|
||||
|
||||
impl ServerHelloBuilder {
|
||||
fn new(session_id: Vec<u8>) -> Self {
|
||||
Self {
|
||||
random: [0u8; 32],
|
||||
session_id,
|
||||
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
|
||||
compression: 0x00,
|
||||
extensions: TlsExtensionBuilder::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
|
||||
self.extensions.add_key_share(key);
|
||||
self
|
||||
}
|
||||
|
||||
fn with_tls13_version(mut self) -> Self {
|
||||
// TLS 1.3 = 0x0304
|
||||
self.extensions.add_supported_versions(0x0304);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build ServerHello message (without record header)
|
||||
fn build_message(&self) -> Vec<u8> {
|
||||
let extensions = self.extensions.extensions.clone();
|
||||
let extensions_len = extensions.len() as u16;
|
||||
|
||||
// Calculate total length
|
||||
let body_len = 2 + // version
|
||||
32 + // random
|
||||
1 + self.session_id.len() + // session_id length + data
|
||||
2 + // cipher suite
|
||||
1 + // compression
|
||||
2 + extensions.len(); // extensions length + data
|
||||
|
||||
let mut message = Vec::with_capacity(4 + body_len);
|
||||
|
||||
// Handshake header
|
||||
message.push(0x02); // ServerHello message type
|
||||
|
||||
// 3-byte length
|
||||
let len_bytes = (body_len as u32).to_be_bytes();
|
||||
message.extend_from_slice(&len_bytes[1..4]);
|
||||
|
||||
// Server version (TLS 1.2 in header, actual version in extension)
|
||||
message.extend_from_slice(&TLS_VERSION);
|
||||
|
||||
// Random (32 bytes) - placeholder, will be replaced with digest
|
||||
message.extend_from_slice(&self.random);
|
||||
|
||||
// Session ID
|
||||
message.push(self.session_id.len() as u8);
|
||||
message.extend_from_slice(&self.session_id);
|
||||
|
||||
// Cipher suite
|
||||
message.extend_from_slice(&self.cipher_suite);
|
||||
|
||||
// Compression method
|
||||
message.push(self.compression);
|
||||
|
||||
// Extensions length
|
||||
message.extend_from_slice(&extensions_len.to_be_bytes());
|
||||
|
||||
// Extensions data
|
||||
message.extend_from_slice(&extensions);
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
/// Build complete ServerHello TLS record
|
||||
fn build_record(&self) -> Vec<u8> {
|
||||
let message = self.build_message();
|
||||
|
||||
let mut record = Vec::with_capacity(5 + message.len());
|
||||
|
||||
// TLS record header
|
||||
record.push(TLS_RECORD_HANDSHAKE);
|
||||
record.extend_from_slice(&TLS_VERSION);
|
||||
record.extend_from_slice(&(message.len() as u16).to_be_bytes());
|
||||
|
||||
// Message
|
||||
record.extend_from_slice(&message);
|
||||
|
||||
record
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Public Functions =============
|
||||
|
||||
/// Validate TLS ClientHello against user secrets
|
||||
///
|
||||
/// Returns validation result if a matching user is found.
|
||||
pub fn validate_tls_handshake(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
@@ -86,7 +292,8 @@ pub fn validate_tls_handshake(
|
||||
// Check time skew
|
||||
if !ignore_time_skew {
|
||||
// Allow very small timestamps (boot time instead of unix time)
|
||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000;
|
||||
// This is a quirk in some clients that use uptime instead of real time
|
||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
|
||||
|
||||
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
|
||||
continue;
|
||||
@@ -105,15 +312,22 @@ pub fn validate_tls_handshake(
|
||||
}
|
||||
|
||||
/// Generate a fake X25519 public key for TLS
|
||||
/// This generates a value that looks like a valid X25519 key
|
||||
///
|
||||
/// This generates random bytes that look like a valid X25519 public key.
|
||||
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
|
||||
pub fn gen_fake_x25519_key() -> [u8; 32] {
|
||||
// For simplicity, just generate random 32 bytes
|
||||
// In real X25519, this would be a point on the curve
|
||||
let bytes = SECURE_RANDOM.bytes(32);
|
||||
bytes.try_into().unwrap()
|
||||
}
|
||||
|
||||
/// Build TLS ServerHello response
|
||||
///
|
||||
/// This builds a complete TLS 1.3-like response including:
|
||||
/// - ServerHello record with extensions
|
||||
/// - Change Cipher Spec record
|
||||
/// - Fake encrypted certificate (Application Data record)
|
||||
///
|
||||
/// The response includes an HMAC digest that the client can verify.
|
||||
pub fn build_server_hello(
|
||||
secret: &[u8],
|
||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||
@@ -122,62 +336,48 @@ pub fn build_server_hello(
|
||||
) -> Vec<u8> {
|
||||
let x25519_key = gen_fake_x25519_key();
|
||||
|
||||
// TLS extensions
|
||||
let mut extensions = Vec::new();
|
||||
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder
|
||||
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension
|
||||
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve
|
||||
extensions.extend_from_slice(&x25519_key);
|
||||
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
|
||||
// Build ServerHello
|
||||
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||
.with_x25519_key(&x25519_key)
|
||||
.with_tls13_version()
|
||||
.build_record();
|
||||
|
||||
// ServerHello body
|
||||
let mut srv_hello = Vec::new();
|
||||
srv_hello.extend_from_slice(&TLS_VERSION);
|
||||
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest
|
||||
srv_hello.push(session_id.len() as u8);
|
||||
srv_hello.extend_from_slice(session_id);
|
||||
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
|
||||
srv_hello.push(0x00); // No compression
|
||||
srv_hello.extend_from_slice(&extensions);
|
||||
|
||||
// Build complete packet
|
||||
let mut hello_pkt = Vec::new();
|
||||
|
||||
// ServerHello record
|
||||
hello_pkt.push(TLS_RECORD_HANDSHAKE);
|
||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
||||
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes());
|
||||
hello_pkt.push(0x02); // ServerHello message type
|
||||
let len_bytes = (srv_hello.len() as u32).to_be_bytes();
|
||||
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length
|
||||
hello_pkt.extend_from_slice(&srv_hello);
|
||||
|
||||
// Change Cipher Spec record
|
||||
hello_pkt.extend_from_slice(&[
|
||||
// Build Change Cipher Spec record
|
||||
let change_cipher_spec = [
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_VERSION[0], TLS_VERSION[1],
|
||||
0x00, 0x01, 0x01
|
||||
]);
|
||||
0x00, 0x01, // length = 1
|
||||
0x01, // CCS byte
|
||||
];
|
||||
|
||||
// Application Data record (fake certificate)
|
||||
// Build fake certificate (Application Data record)
|
||||
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
|
||||
hello_pkt.push(TLS_RECORD_APPLICATION);
|
||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
||||
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
|
||||
hello_pkt.extend_from_slice(&fake_cert);
|
||||
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||
app_data_record.extend_from_slice(&(fake_cert_len as u16).to_be_bytes());
|
||||
app_data_record.extend_from_slice(&fake_cert);
|
||||
|
||||
// Combine all records
|
||||
let mut response = Vec::with_capacity(
|
||||
server_hello.len() + change_cipher_spec.len() + app_data_record.len()
|
||||
);
|
||||
response.extend_from_slice(&server_hello);
|
||||
response.extend_from_slice(&change_cipher_spec);
|
||||
response.extend_from_slice(&app_data_record);
|
||||
|
||||
// Compute HMAC for the response
|
||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len());
|
||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
|
||||
hmac_input.extend_from_slice(client_digest);
|
||||
hmac_input.extend_from_slice(&hello_pkt);
|
||||
hmac_input.extend_from_slice(&response);
|
||||
let response_digest = sha256_hmac(secret, &hmac_input);
|
||||
|
||||
// Insert computed digest
|
||||
// Position: after record header (5) + message type/length (4) + version (2) = 11
|
||||
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||
// Insert computed digest into ServerHello
|
||||
// Position: record header (5) + message type (1) + length (3) + version (2) = 11
|
||||
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&response_digest);
|
||||
|
||||
hello_pkt
|
||||
response
|
||||
}
|
||||
|
||||
/// Check if bytes look like a TLS ClientHello
|
||||
@@ -186,7 +386,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TLS record header: 0x16 0x03 0x01
|
||||
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
|
||||
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
||||
&& first_bytes[1] == 0x03
|
||||
&& first_bytes[2] == 0x01
|
||||
@@ -206,6 +406,61 @@ pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
||||
Some((record_type, length))
|
||||
}
|
||||
|
||||
/// Validate a ServerHello response structure
|
||||
///
|
||||
/// This is useful for testing that our ServerHello is well-formed.
|
||||
#[cfg(test)]
|
||||
fn validate_server_hello_structure(data: &[u8]) -> Result<()> {
|
||||
if data.len() < 5 {
|
||||
return Err(ProxyError::InvalidTlsRecord {
|
||||
record_type: 0,
|
||||
version: [0, 0],
|
||||
});
|
||||
}
|
||||
|
||||
// Check record header
|
||||
if data[0] != TLS_RECORD_HANDSHAKE {
|
||||
return Err(ProxyError::InvalidTlsRecord {
|
||||
record_type: data[0],
|
||||
version: [data[1], data[2]],
|
||||
});
|
||||
}
|
||||
|
||||
// Check version
|
||||
if data[1..3] != TLS_VERSION {
|
||||
return Err(ProxyError::InvalidTlsRecord {
|
||||
record_type: data[0],
|
||||
version: [data[1], data[2]],
|
||||
});
|
||||
}
|
||||
|
||||
// Check record length
|
||||
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
|
||||
if data.len() < 5 + record_len {
|
||||
return Err(ProxyError::InvalidHandshake(
|
||||
format!("ServerHello record truncated: expected {}, got {}",
|
||||
5 + record_len, data.len())
|
||||
));
|
||||
}
|
||||
|
||||
// Check message type
|
||||
if data[5] != 0x02 {
|
||||
return Err(ProxyError::InvalidHandshake(
|
||||
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
|
||||
));
|
||||
}
|
||||
|
||||
// Parse message length
|
||||
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
|
||||
if msg_len + 4 != record_len {
|
||||
return Err(ProxyError::InvalidHandshake(
|
||||
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -241,4 +496,145 @@ mod tests {
|
||||
assert_eq!(key2.len(), 32);
|
||||
assert_ne!(key1, key2); // Should be random
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_extension_builder() {
|
||||
let key = [0x42u8; 32];
|
||||
|
||||
let mut builder = TlsExtensionBuilder::new();
|
||||
builder.add_key_share(&key);
|
||||
builder.add_supported_versions(0x0304);
|
||||
|
||||
let result = builder.build();
|
||||
|
||||
// Check length prefix
|
||||
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
|
||||
assert_eq!(len, result.len() - 2);
|
||||
|
||||
// Check key_share extension is present
|
||||
assert!(result.len() > 40); // At least key share
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_hello_builder() {
|
||||
let session_id = vec![0x01, 0x02, 0x03, 0x04];
|
||||
let key = [0x55u8; 32];
|
||||
|
||||
let builder = ServerHelloBuilder::new(session_id.clone())
|
||||
.with_x25519_key(&key)
|
||||
.with_tls13_version();
|
||||
|
||||
let record = builder.build_record();
|
||||
|
||||
// Validate structure
|
||||
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
|
||||
|
||||
// Check record type
|
||||
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
|
||||
|
||||
// Check version
|
||||
assert_eq!(&record[1..3], &TLS_VERSION);
|
||||
|
||||
// Check message type (ServerHello = 0x02)
|
||||
assert_eq!(record[5], 0x02);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_server_hello_structure() {
|
||||
let secret = b"test secret";
|
||||
let client_digest = [0x42u8; 32];
|
||||
let session_id = vec![0xAA; 32];
|
||||
|
||||
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
|
||||
|
||||
// Should have at least 3 records
|
||||
assert!(response.len() > 100);
|
||||
|
||||
// First record should be ServerHello
|
||||
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
|
||||
|
||||
// Validate ServerHello structure
|
||||
validate_server_hello_structure(&response).expect("Invalid ServerHello");
|
||||
|
||||
// Find Change Cipher Spec
|
||||
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||
let ccs_start = server_hello_len;
|
||||
|
||||
assert!(response.len() > ccs_start + 6);
|
||||
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
|
||||
|
||||
// Find Application Data
|
||||
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
|
||||
let app_start = ccs_start + ccs_len;
|
||||
|
||||
assert!(response.len() > app_start + 5);
|
||||
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_server_hello_digest() {
|
||||
let secret = b"test secret key here";
|
||||
let client_digest = [0x42u8; 32];
|
||||
let session_id = vec![0xAA; 32];
|
||||
|
||||
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
||||
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
||||
|
||||
// Digest position should have non-zero data
|
||||
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||
assert!(!digest1.iter().all(|&b| b == 0));
|
||||
|
||||
// Different calls should have different digests (due to random cert)
|
||||
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||
assert_ne!(digest1, digest2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_hello_extensions_length() {
|
||||
let session_id = vec![0x01; 32];
|
||||
let key = [0x55u8; 32];
|
||||
|
||||
let builder = ServerHelloBuilder::new(session_id)
|
||||
.with_x25519_key(&key)
|
||||
.with_tls13_version();
|
||||
|
||||
let record = builder.build_record();
|
||||
|
||||
// Parse to find extensions
|
||||
let msg_start = 5; // After record header
|
||||
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
|
||||
|
||||
// Skip to session ID
|
||||
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
|
||||
let session_id_len = record[session_id_pos] as usize;
|
||||
|
||||
// Skip to extensions
|
||||
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
|
||||
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
|
||||
|
||||
// Verify extensions length matches actual data
|
||||
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
|
||||
assert_eq!(ext_len, extensions_data.len(),
|
||||
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tls_handshake_format() {
|
||||
// Build a minimal ClientHello-like structure
|
||||
let mut handshake = vec![0u8; 100];
|
||||
|
||||
// Put a valid-looking digest at position 11
|
||||
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&[0x42; 32]);
|
||||
|
||||
// Session ID length
|
||||
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
|
||||
|
||||
// This won't validate (wrong HMAC) but shouldn't panic
|
||||
let secrets = vec![("test".to_string(), b"secret".to_vec())];
|
||||
let result = validate_tls_handshake(&handshake, &secrets, true);
|
||||
|
||||
// Should return None (no match) but not panic
|
||||
assert!(result.is_none());
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ use crate::error::{ProxyError, Result, HandshakeResult};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::tls;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{ConnectionPool, configure_client_socket};
|
||||
use crate::transport::{configure_client_socket, UpstreamManager};
|
||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
||||
use crate::crypto::AesCtr;
|
||||
|
||||
@@ -24,39 +24,54 @@ use super::handshake::{
|
||||
use super::relay::relay_bidirectional;
|
||||
use super::masking::handle_bad_client;
|
||||
|
||||
/// Client connection handler
|
||||
pub struct ClientHandler {
|
||||
/// Client connection handler (builder struct)
|
||||
pub struct ClientHandler;
|
||||
|
||||
/// Running client handler with stream and context
|
||||
pub struct RunningClientHandler {
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create new client handler
|
||||
/// Create new client handler instance
|
||||
pub fn new(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker
|
||||
) -> RunningClientHandler {
|
||||
// CHANGED: Removed local creation of ReplayChecker.
|
||||
// It is now passed from main.rs to ensure global replay protection.
|
||||
|
||||
RunningClientHandler {
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
replay_checker,
|
||||
pool,
|
||||
upstream_manager,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a client connection
|
||||
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
|
||||
impl RunningClientHandler {
|
||||
/// Run the client handler
|
||||
pub async fn run(mut self) -> Result<()> {
|
||||
self.stats.increment_connects_all();
|
||||
|
||||
let peer = self.peer;
|
||||
debug!(peer = %peer, "New connection");
|
||||
|
||||
// Configure socket
|
||||
if let Err(e) = configure_client_socket(
|
||||
&stream,
|
||||
&self.stream,
|
||||
self.config.client_keepalive,
|
||||
self.config.client_ack_timeout,
|
||||
) {
|
||||
@@ -66,49 +81,56 @@ impl ClientHandler {
|
||||
// Perform handshake with timeout
|
||||
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
|
||||
|
||||
// Clone stats for error handling block
|
||||
let stats = self.stats.clone();
|
||||
|
||||
let result = timeout(
|
||||
handshake_timeout,
|
||||
self.do_handshake(stream, peer)
|
||||
self.do_handshake()
|
||||
).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
debug!(peer = %peer, "Connection handled successfully");
|
||||
Ok(())
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||
Err(e)
|
||||
}
|
||||
Err(_) => {
|
||||
self.stats.increment_handshake_timeouts();
|
||||
stats.increment_handshake_timeouts();
|
||||
debug!(peer = %peer, "Handshake timeout");
|
||||
Err(ProxyError::TgHandshakeTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake and relay
|
||||
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
|
||||
async fn do_handshake(mut self) -> Result<()> {
|
||||
// Read first bytes to determine handshake type
|
||||
let mut first_bytes = [0u8; 5];
|
||||
stream.read_exact(&mut first_bytes).await?;
|
||||
self.stream.read_exact(&mut first_bytes).await?;
|
||||
|
||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||
let peer = self.peer;
|
||||
|
||||
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
|
||||
|
||||
if is_tls {
|
||||
self.handle_tls_client(stream, peer, first_bytes).await
|
||||
self.handle_tls_client(first_bytes).await
|
||||
} else {
|
||||
self.handle_direct_client(stream, peer, first_bytes).await
|
||||
self.handle_direct_client(first_bytes).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle TLS-wrapped client
|
||||
async fn handle_tls_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
mut self,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
let peer = self.peer;
|
||||
|
||||
// Read TLS handshake length
|
||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||
|
||||
@@ -117,17 +139,22 @@ impl ClientHandler {
|
||||
if tls_len < 512 {
|
||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
handle_bad_client(self.stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read full TLS handshake
|
||||
let mut handshake = vec![0u8; 5 + tls_len];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Extract fields before consuming self.stream
|
||||
let config = self.config.clone();
|
||||
let replay_checker = self.replay_checker.clone();
|
||||
let stats = self.stats.clone();
|
||||
|
||||
// Split stream for reading/writing
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
// Handle TLS handshake
|
||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||
@@ -135,12 +162,12 @@ impl ClientHandler {
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
@@ -158,44 +185,62 @@ impl ClientHandler {
|
||||
tls_reader,
|
||||
tls_writer,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
true,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
// Handle authenticated client
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
// We can't use self.handle_authenticated_inner because self is partially moved
|
||||
// So we call it as an associated function or method on a new struct,
|
||||
// or just inline the logic / use a static method.
|
||||
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
|
||||
// we should pass them explicitly.
|
||||
|
||||
Self::handle_authenticated_static(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config
|
||||
).await
|
||||
}
|
||||
|
||||
/// Handle direct (non-TLS) client
|
||||
async fn handle_direct_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
mut self,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
let peer = self.peer;
|
||||
|
||||
// Check if non-TLS modes are enabled
|
||||
if !self.config.modes.classic && !self.config.modes.secure {
|
||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
handle_bad_client(self.stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read rest of handshake
|
||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Extract fields
|
||||
let config = self.config.clone();
|
||||
let replay_checker = self.replay_checker.clone();
|
||||
let stats = self.stats.clone();
|
||||
|
||||
// Split stream
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
// Handle MTProto handshake
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
@@ -203,27 +248,36 @@ impl ClientHandler {
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
false,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
Self::handle_authenticated_static(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config
|
||||
).await
|
||||
}
|
||||
|
||||
/// Handle authenticated client - connect to Telegram and relay
|
||||
async fn handle_authenticated_inner<R, W>(
|
||||
&self,
|
||||
/// Static version of handle_authenticated_inner to avoid ownership issues
|
||||
async fn handle_authenticated_static<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -232,13 +286,13 @@ impl ClientHandler {
|
||||
let user = &success.user;
|
||||
|
||||
// Check user limits
|
||||
if let Err(e) = self.check_user_limits(user) {
|
||||
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Get datacenter address
|
||||
let dc_addr = self.get_dc_addr(success.dc_idx)?;
|
||||
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
|
||||
|
||||
info!(
|
||||
user = %user,
|
||||
@@ -246,39 +300,40 @@ impl ClientHandler {
|
||||
dc = success.dc_idx,
|
||||
dc_addr = %dc_addr,
|
||||
proto = ?success.proto_tag,
|
||||
fast_mode = self.config.fast_mode,
|
||||
fast_mode = config.fast_mode,
|
||||
"Connecting to Telegram"
|
||||
);
|
||||
|
||||
// Connect to Telegram
|
||||
let tg_stream = self.pool.get(dc_addr).await?;
|
||||
// Connect to Telegram via UpstreamManager
|
||||
let tg_stream = upstream_manager.connect(dc_addr).await?;
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
|
||||
|
||||
// Perform Telegram handshake and get crypto streams
|
||||
let (tg_reader, tg_writer) = self.do_tg_handshake(
|
||||
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
|
||||
tg_stream,
|
||||
&success,
|
||||
&config,
|
||||
).await?;
|
||||
|
||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
||||
|
||||
// Update stats
|
||||
self.stats.increment_user_connects(user);
|
||||
self.stats.increment_user_curr_connects(user);
|
||||
stats.increment_user_connects(user);
|
||||
stats.increment_user_curr_connects(user);
|
||||
|
||||
// Relay traffic - передаём Arc::clone(&self.stats)
|
||||
// Relay traffic
|
||||
let relay_result = relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
tg_writer,
|
||||
user,
|
||||
Arc::clone(&self.stats),
|
||||
Arc::clone(&stats),
|
||||
).await;
|
||||
|
||||
// Update stats
|
||||
self.stats.decrement_user_curr_connects(user);
|
||||
stats.decrement_user_curr_connects(user);
|
||||
|
||||
match &relay_result {
|
||||
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
|
||||
@@ -288,26 +343,26 @@ impl ClientHandler {
|
||||
relay_result
|
||||
}
|
||||
|
||||
/// Check user limits (expiration, connection count, data quota)
|
||||
fn check_user_limits(&self, user: &str) -> Result<()> {
|
||||
/// Check user limits (static version)
|
||||
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
|
||||
// Check expiration
|
||||
if let Some(expiration) = self.config.user_expirations.get(user) {
|
||||
if let Some(expiration) = config.user_expirations.get(user) {
|
||||
if chrono::Utc::now() > *expiration {
|
||||
return Err(ProxyError::UserExpired { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection limit
|
||||
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
|
||||
let current = self.stats.get_user_curr_connects(user);
|
||||
if let Some(limit) = config.user_max_tcp_conns.get(user) {
|
||||
let current = stats.get_user_curr_connects(user);
|
||||
if current >= *limit as u64 {
|
||||
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check data quota
|
||||
if let Some(quota) = self.config.user_data_quota.get(user) {
|
||||
let used = self.stats.get_user_total_octets(user);
|
||||
if let Some(quota) = config.user_data_quota.get(user) {
|
||||
let used = stats.get_user_total_octets(user);
|
||||
if used >= *quota {
|
||||
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
||||
}
|
||||
@@ -316,11 +371,11 @@ impl ClientHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get datacenter address by index
|
||||
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
|
||||
/// Get datacenter address by index (static version)
|
||||
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||
let idx = (dc_idx.abs() - 1) as usize;
|
||||
|
||||
let datacenters = if self.config.prefer_ipv6 {
|
||||
let datacenters = if config.prefer_ipv6 {
|
||||
&*TG_DATACENTERS_V6
|
||||
} else {
|
||||
&*TG_DATACENTERS_V4
|
||||
@@ -333,19 +388,18 @@ impl ClientHandler {
|
||||
))
|
||||
}
|
||||
|
||||
/// Perform handshake with Telegram server
|
||||
/// Returns crypto reader and writer for TG connection
|
||||
async fn do_tg_handshake(
|
||||
&self,
|
||||
/// Perform handshake with Telegram server (static version)
|
||||
async fn do_tg_handshake_static(
|
||||
mut stream: TcpStream,
|
||||
success: &HandshakeSuccess,
|
||||
config: &ProxyConfig,
|
||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
||||
// Generate nonce with keys for TG
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||
success.proto_tag,
|
||||
&success.dec_key, // Client's dec key
|
||||
success.dec_iv,
|
||||
self.config.fast_mode,
|
||||
config.fast_mode,
|
||||
);
|
||||
|
||||
// Encrypt nonce
|
||||
|
||||
@@ -13,7 +13,7 @@ const MASK_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
/// Handle a bad client by forwarding to mask host
|
||||
pub async fn handle_bad_client(
|
||||
mut client: TcpStream,
|
||||
client: TcpStream,
|
||||
initial_data: &[u8],
|
||||
config: &ProxyConfig,
|
||||
) {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
//! Bidirectional Relay
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{debug, trace, warn};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, trace, warn, info};
|
||||
use crate::error::Result;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
const BUFFER_SIZE: usize = 65536;
|
||||
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
|
||||
const BUFFER_SIZE: usize = 16384;
|
||||
|
||||
// Activity timeout for iOS compatibility (30 minutes)
|
||||
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
|
||||
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
|
||||
|
||||
/// Relay data bidirectionally between client and server
|
||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
@@ -36,15 +44,40 @@ where
|
||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
||||
|
||||
// Client -> Server task
|
||||
// Activity timeout for iOS compatibility
|
||||
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
|
||||
|
||||
// Client -> Server task with activity timeout
|
||||
let c2s = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
let mut last_activity = Instant::now();
|
||||
let mut last_log = Instant::now();
|
||||
|
||||
loop {
|
||||
match client_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
// Read with timeout to prevent infinite hang on iOS
|
||||
let read_result = tokio::time::timeout(
|
||||
activity_timeout,
|
||||
client_reader.read(&mut buf)
|
||||
).await;
|
||||
|
||||
match read_result {
|
||||
// Timeout - no activity for too long
|
||||
Err(_) => {
|
||||
warn!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
idle_secs = last_activity.elapsed().as_secs(),
|
||||
"Activity timeout (C->S) - no data received"
|
||||
);
|
||||
let _ = server_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
|
||||
// Read successful
|
||||
Ok(Ok(0)) => {
|
||||
debug!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
@@ -54,9 +87,11 @@ where
|
||||
let _ = server_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
|
||||
Ok(Ok(n)) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
last_activity = Instant::now();
|
||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
||||
@@ -70,6 +105,19 @@ where
|
||||
"C->S data"
|
||||
);
|
||||
|
||||
// Log activity every 10 seconds for large transfers
|
||||
if last_log.elapsed() > Duration::from_secs(10) {
|
||||
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||
info!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
rate_kbps = (rate / 1024.0) as u64,
|
||||
"C->S transfer in progress"
|
||||
);
|
||||
last_log = Instant::now();
|
||||
}
|
||||
|
||||
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
||||
break;
|
||||
@@ -79,7 +127,8 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
Ok(Err(e)) => {
|
||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
||||
break;
|
||||
}
|
||||
@@ -87,15 +136,37 @@ where
|
||||
}
|
||||
});
|
||||
|
||||
// Server -> Client task
|
||||
// Server -> Client task with activity timeout
|
||||
let s2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
let mut last_activity = Instant::now();
|
||||
let mut last_log = Instant::now();
|
||||
|
||||
loop {
|
||||
match server_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
// Read with timeout to prevent infinite hang on iOS
|
||||
let read_result = tokio::time::timeout(
|
||||
activity_timeout,
|
||||
server_reader.read(&mut buf)
|
||||
).await;
|
||||
|
||||
match read_result {
|
||||
// Timeout - no activity for too long
|
||||
Err(_) => {
|
||||
warn!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
idle_secs = last_activity.elapsed().as_secs(),
|
||||
"Activity timeout (S->C) - no data received"
|
||||
);
|
||||
let _ = client_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
|
||||
// Read successful
|
||||
Ok(Ok(0)) => {
|
||||
debug!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
@@ -105,9 +176,11 @@ where
|
||||
let _ = client_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
|
||||
Ok(Ok(n)) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
last_activity = Instant::now();
|
||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
||||
@@ -121,6 +194,19 @@ where
|
||||
"S->C data"
|
||||
);
|
||||
|
||||
// Log activity every 10 seconds for large transfers
|
||||
if last_log.elapsed() > Duration::from_secs(10) {
|
||||
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||
info!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
rate_kbps = (rate / 1024.0) as u64,
|
||||
"S->C transfer in progress"
|
||||
);
|
||||
last_log = Instant::now();
|
||||
}
|
||||
|
||||
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
||||
break;
|
||||
@@ -130,7 +216,8 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
Ok(Err(e)) => {
|
||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
||||
break;
|
||||
}
|
||||
|
||||
451
src/stream/buffer_pool.rs
Normal file
451
src/stream/buffer_pool.rs
Normal file
@@ -0,0 +1,451 @@
|
||||
//! Reusable buffer pool to avoid allocations in hot paths
|
||||
//!
|
||||
//! This module provides a thread-safe pool of BytesMut buffers
|
||||
//! that can be reused across connections to reduce allocation pressure.
|
||||
|
||||
use bytes::BytesMut;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
// ============= Configuration =============
|
||||
|
||||
/// Default buffer size
|
||||
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
||||
|
||||
/// Default maximum number of pooled buffers
|
||||
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||
|
||||
// ============= Buffer Pool =============
|
||||
|
||||
/// Thread-safe pool of reusable buffers
|
||||
pub struct BufferPool {
|
||||
/// Queue of available buffers
|
||||
buffers: ArrayQueue<BytesMut>,
|
||||
/// Size of each buffer
|
||||
buffer_size: usize,
|
||||
/// Maximum number of buffers to pool
|
||||
max_buffers: usize,
|
||||
/// Total allocated buffers (including in-use)
|
||||
allocated: AtomicUsize,
|
||||
/// Number of times we had to create a new buffer
|
||||
misses: AtomicUsize,
|
||||
/// Number of successful reuses
|
||||
hits: AtomicUsize,
|
||||
}
|
||||
|
||||
impl BufferPool {
|
||||
/// Create a new buffer pool with default settings
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
|
||||
}
|
||||
|
||||
/// Create a buffer pool with custom configuration
|
||||
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
|
||||
Self {
|
||||
buffers: ArrayQueue::new(max_buffers),
|
||||
buffer_size,
|
||||
max_buffers,
|
||||
allocated: AtomicUsize::new(0),
|
||||
misses: AtomicUsize::new(0),
|
||||
hits: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a buffer from the pool, or create a new one if empty
|
||||
pub fn get(self: &Arc<Self>) -> PooledBuffer {
|
||||
match self.buffers.pop() {
|
||||
Some(mut buffer) => {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
buffer.clear();
|
||||
PooledBuffer {
|
||||
buffer: Some(buffer),
|
||||
pool: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||
PooledBuffer {
|
||||
buffer: Some(BytesMut::with_capacity(self.buffer_size)),
|
||||
pool: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get a buffer, returns None if pool is empty
|
||||
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
|
||||
self.buffers.pop().map(|mut buffer| {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
buffer.clear();
|
||||
PooledBuffer {
|
||||
buffer: Some(buffer),
|
||||
pool: Arc::clone(self),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a buffer to the pool
|
||||
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||
// Clear the buffer but keep capacity
|
||||
buffer.clear();
|
||||
|
||||
// Only return if we haven't exceeded max and buffer is right size
|
||||
if buffer.capacity() >= self.buffer_size {
|
||||
// Try to push to pool, if full just drop
|
||||
let _ = self.buffers.push(buffer);
|
||||
}
|
||||
// If buffer was dropped (pool full), decrement allocated
|
||||
// Actually we don't decrement here because the buffer might have been
|
||||
// grown beyond our size - we just let it go
|
||||
}
|
||||
|
||||
/// Get pool statistics
|
||||
pub fn stats(&self) -> PoolStats {
|
||||
PoolStats {
|
||||
pooled: self.buffers.len(),
|
||||
allocated: self.allocated.load(Ordering::Relaxed),
|
||||
max_buffers: self.max_buffers,
|
||||
buffer_size: self.buffer_size,
|
||||
hits: self.hits.load(Ordering::Relaxed),
|
||||
misses: self.misses.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get buffer size
|
||||
pub fn buffer_size(&self) -> usize {
|
||||
self.buffer_size
|
||||
}
|
||||
|
||||
/// Preallocate buffers to fill the pool
|
||||
pub fn preallocate(&self, count: usize) {
|
||||
let to_alloc = count.min(self.max_buffers);
|
||||
for _ in 0..to_alloc {
|
||||
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
|
||||
break;
|
||||
}
|
||||
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BufferPool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Pool Statistics =============
|
||||
|
||||
/// Statistics about buffer pool usage
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PoolStats {
|
||||
/// Current number of buffers in pool
|
||||
pub pooled: usize,
|
||||
/// Total buffers allocated (in-use + pooled)
|
||||
pub allocated: usize,
|
||||
/// Maximum buffers allowed
|
||||
pub max_buffers: usize,
|
||||
/// Size of each buffer
|
||||
pub buffer_size: usize,
|
||||
/// Number of cache hits (reused buffer)
|
||||
pub hits: usize,
|
||||
/// Number of cache misses (new allocation)
|
||||
pub misses: usize,
|
||||
}
|
||||
|
||||
impl PoolStats {
|
||||
/// Get hit rate as percentage
|
||||
pub fn hit_rate(&self) -> f64 {
|
||||
let total = self.hits + self.misses;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(self.hits as f64 / total as f64) * 100.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Pooled Buffer =============
|
||||
|
||||
/// A buffer that automatically returns to the pool when dropped
|
||||
pub struct PooledBuffer {
|
||||
buffer: Option<BytesMut>,
|
||||
pool: Arc<BufferPool>,
|
||||
}
|
||||
|
||||
impl PooledBuffer {
|
||||
/// Take the inner buffer, preventing return to pool
|
||||
pub fn take(mut self) -> BytesMut {
|
||||
self.buffer.take().unwrap()
|
||||
}
|
||||
|
||||
/// Get the capacity of the buffer
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Get the length of data in buffer
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
if let Some(ref mut b) = self.buffer {
|
||||
b.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PooledBuffer {
|
||||
type Target = BytesMut;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.buffer.as_ref().expect("buffer taken")
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PooledBuffer {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.buffer.as_mut().expect("buffer taken")
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PooledBuffer {
|
||||
fn drop(&mut self) {
|
||||
if let Some(buffer) = self.buffer.take() {
|
||||
self.pool.return_buffer(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for PooledBuffer {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[])
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<[u8]> for PooledBuffer {
|
||||
fn as_mut(&mut self) -> &mut [u8] {
|
||||
self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut [])
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Scoped Buffer =============
|
||||
|
||||
/// A buffer that can be used for a scoped operation
|
||||
/// Useful for ensuring buffer is returned even on early return
|
||||
pub struct ScopedBuffer<'a> {
|
||||
buffer: &'a mut PooledBuffer,
|
||||
}
|
||||
|
||||
impl<'a> ScopedBuffer<'a> {
|
||||
/// Create a new scoped buffer
|
||||
pub fn new(buffer: &'a mut PooledBuffer) -> Self {
|
||||
buffer.clear();
|
||||
Self { buffer }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for ScopedBuffer<'a> {
|
||||
type Target = BytesMut;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.buffer.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> DerefMut for ScopedBuffer<'a> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.buffer.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for ScopedBuffer<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pool_basic() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
// Get a buffer
|
||||
let mut buf1 = pool.get();
|
||||
buf1.extend_from_slice(b"hello");
|
||||
assert_eq!(&buf1[..], b"hello");
|
||||
|
||||
// Drop returns to pool
|
||||
drop(buf1);
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 1);
|
||||
assert_eq!(stats.hits, 0);
|
||||
assert_eq!(stats.misses, 1);
|
||||
|
||||
// Get again - should reuse
|
||||
let buf2 = pool.get();
|
||||
assert!(buf2.is_empty()); // Buffer was cleared
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 0);
|
||||
assert_eq!(stats.hits, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_multiple_buffers() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
// Get multiple buffers
|
||||
let buf1 = pool.get();
|
||||
let buf2 = pool.get();
|
||||
let buf3 = pool.get();
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.allocated, 3);
|
||||
assert_eq!(stats.pooled, 0);
|
||||
|
||||
// Return all
|
||||
drop(buf1);
|
||||
drop(buf2);
|
||||
drop(buf3);
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_overflow() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 2));
|
||||
|
||||
// Get 3 buffers (more than max)
|
||||
let buf1 = pool.get();
|
||||
let buf2 = pool.get();
|
||||
let buf3 = pool.get();
|
||||
|
||||
// Return all - only 2 should be pooled
|
||||
drop(buf1);
|
||||
drop(buf2);
|
||||
drop(buf3);
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_take() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
let mut buf = pool.get();
|
||||
buf.extend_from_slice(b"data");
|
||||
|
||||
// Take ownership, buffer should not return to pool
|
||||
let taken = buf.take();
|
||||
assert_eq!(&taken[..], b"data");
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_preallocate() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
pool.preallocate(5);
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 5);
|
||||
assert_eq!(stats.allocated, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_try_get() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
// Pool is empty, try_get returns None
|
||||
assert!(pool.try_get().is_none());
|
||||
|
||||
// Add a buffer to pool
|
||||
pool.preallocate(1);
|
||||
|
||||
// Now try_get should succeed
|
||||
assert!(pool.try_get().is_some());
|
||||
assert!(pool.try_get().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hit_rate() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
// First get is a miss
|
||||
let buf1 = pool.get();
|
||||
drop(buf1);
|
||||
|
||||
// Second get is a hit
|
||||
let buf2 = pool.get();
|
||||
drop(buf2);
|
||||
|
||||
// Third get is a hit
|
||||
let _buf3 = pool.get();
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.hits, 2);
|
||||
assert_eq!(stats.misses, 1);
|
||||
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scoped_buffer() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
let mut buf = pool.get();
|
||||
|
||||
{
|
||||
let mut scoped = ScopedBuffer::new(&mut buf);
|
||||
scoped.extend_from_slice(b"scoped data");
|
||||
assert_eq!(&scoped[..], b"scoped data");
|
||||
}
|
||||
|
||||
// After scoped is dropped, buffer is cleared
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_access() {
|
||||
use std::thread;
|
||||
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 100));
|
||||
let mut handles = vec![];
|
||||
|
||||
for _ in 0..10 {
|
||||
let pool_clone = Arc::clone(&pool);
|
||||
handles.push(thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
let mut buf = pool_clone.get();
|
||||
buf.extend_from_slice(b"test");
|
||||
// buf auto-returned on drop
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
let stats = pool.stats();
|
||||
// All buffers should be returned
|
||||
assert!(stats.pooled > 0);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
187
src/stream/frame.rs
Normal file
187
src/stream/frame.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
//! MTProto frame types and traits
|
||||
//!
|
||||
//! This module defines the common types and traits used by all
|
||||
//! frame encoding/decoding implementations.
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Result;
|
||||
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
|
||||
// ============= Frame Types =============
|
||||
|
||||
/// A decoded MTProto frame
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Frame {
|
||||
/// Frame payload data
|
||||
pub data: Bytes,
|
||||
/// Frame metadata
|
||||
pub meta: FrameMeta,
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
/// Create a new frame with data and default metadata
|
||||
pub fn new(data: Bytes) -> Self {
|
||||
Self {
|
||||
data,
|
||||
meta: FrameMeta::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new frame with data and metadata
|
||||
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
|
||||
Self { data, meta }
|
||||
}
|
||||
|
||||
/// Create an empty frame
|
||||
pub fn empty() -> Self {
|
||||
Self::new(Bytes::new())
|
||||
}
|
||||
|
||||
/// Check if frame is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Get frame length
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Create a QuickAck request frame
|
||||
pub fn quickack(data: Bytes) -> Self {
|
||||
Self {
|
||||
data,
|
||||
meta: FrameMeta {
|
||||
quickack: true,
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a simple ACK frame
|
||||
pub fn simple_ack(data: Bytes) -> Self {
|
||||
Self {
|
||||
data,
|
||||
meta: FrameMeta {
|
||||
simple_ack: true,
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Frame metadata
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FrameMeta {
|
||||
/// Quick ACK requested - client wants immediate acknowledgment
|
||||
pub quickack: bool,
|
||||
/// This is a simple ACK message (reversed data)
|
||||
pub simple_ack: bool,
|
||||
/// Original padding length (for secure mode)
|
||||
pub padding_len: u8,
|
||||
}
|
||||
|
||||
impl FrameMeta {
|
||||
/// Create new empty metadata
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create with quickack flag
|
||||
pub fn with_quickack(mut self) -> Self {
|
||||
self.quickack = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Create with simple_ack flag
|
||||
pub fn with_simple_ack(mut self) -> Self {
|
||||
self.simple_ack = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Create with padding length
|
||||
pub fn with_padding(mut self, len: u8) -> Self {
|
||||
self.padding_len = len;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if any special flags are set
|
||||
pub fn has_flags(&self) -> bool {
|
||||
self.quickack || self.simple_ack
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Codec Trait =============
|
||||
|
||||
/// Trait for frame codecs that can encode and decode frames
|
||||
pub trait FrameCodec: Send + Sync {
|
||||
/// Get the protocol tag for this codec
|
||||
fn proto_tag(&self) -> ProtoTag;
|
||||
|
||||
/// Encode a frame into the destination buffer
|
||||
///
|
||||
/// Returns the number of bytes written.
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;
|
||||
|
||||
/// Try to decode a frame from the source buffer
|
||||
///
|
||||
/// Returns:
|
||||
/// - `Ok(Some(frame))` if a complete frame was decoded
|
||||
/// - `Ok(None)` if more data is needed
|
||||
/// - `Err(e)` if an error occurred
|
||||
///
|
||||
/// On success, the consumed bytes are removed from `src`.
|
||||
fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;
|
||||
|
||||
/// Get the minimum bytes needed to determine frame length
|
||||
fn min_header_size(&self) -> usize;
|
||||
|
||||
/// Get the maximum allowed frame size
|
||||
fn max_frame_size(&self) -> usize {
|
||||
// Default: 16MB
|
||||
16 * 1024 * 1024
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Codec Factory =============
|
||||
|
||||
/// Create a frame codec for the given protocol tag
|
||||
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
|
||||
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
|
||||
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_frame_creation() {
|
||||
let frame = Frame::new(Bytes::from_static(b"test"));
|
||||
assert_eq!(frame.len(), 4);
|
||||
assert!(!frame.is_empty());
|
||||
assert!(!frame.meta.quickack);
|
||||
|
||||
let frame = Frame::empty();
|
||||
assert!(frame.is_empty());
|
||||
|
||||
let frame = Frame::quickack(Bytes::from_static(b"ack"));
|
||||
assert!(frame.meta.quickack);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_meta() {
|
||||
let meta = FrameMeta::new()
|
||||
.with_quickack()
|
||||
.with_padding(3);
|
||||
|
||||
assert!(meta.quickack);
|
||||
assert!(!meta.simple_ack);
|
||||
assert_eq!(meta.padding_len, 3);
|
||||
assert!(meta.has_flags());
|
||||
}
|
||||
}
|
||||
621
src/stream/frame_codec.rs
Normal file
621
src/stream/frame_codec.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
//! tokio-util codec integration for MTProto frames
|
||||
//!
|
||||
//! This module provides Encoder/Decoder implementations compatible
|
||||
//! with tokio-util's Framed wrapper for easy async frame I/O.
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use std::io::{self, Error, ErrorKind};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
||||
|
||||
// ============= Unified Codec =============
|
||||
|
||||
/// Unified frame codec that wraps all protocol variants
|
||||
///
|
||||
/// This codec implements tokio-util's Encoder and Decoder traits,
|
||||
/// allowing it to be used with `Framed` for async frame I/O.
|
||||
pub struct FrameCodec {
|
||||
/// Protocol variant
|
||||
proto_tag: ProtoTag,
|
||||
/// Maximum allowed frame size
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl FrameCodec {
|
||||
/// Create a new codec for the given protocol
|
||||
pub fn new(proto_tag: ProtoTag) -> Self {
|
||||
Self {
|
||||
proto_tag,
|
||||
max_frame_size: 16 * 1024 * 1024, // 16MB default
|
||||
}
|
||||
}
|
||||
|
||||
/// Set maximum frame size
|
||||
pub fn with_max_frame_size(mut self, size: usize) -> Self {
|
||||
self.max_frame_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get protocol tag
|
||||
pub fn proto_tag(&self) -> ProtoTag {
|
||||
self.proto_tag
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for FrameCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
match self.proto_tag {
|
||||
ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
|
||||
ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size),
|
||||
ProtoTag::Secure => decode_secure(src, self.max_frame_size),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Frame> for FrameCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match self.proto_tag {
|
||||
ProtoTag::Abridged => encode_abridged(&frame, dst),
|
||||
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
|
||||
ProtoTag::Secure => encode_secure(&frame, dst),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Abridged Protocol =============
|
||||
|
||||
fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
|
||||
if src.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let first_byte = src[0];
|
||||
|
||||
// Extract length and quickack flag
|
||||
let mut len_words = (first_byte & 0x7f) as usize;
|
||||
if first_byte >= 0x80 {
|
||||
meta.quickack = true;
|
||||
}
|
||||
|
||||
let header_len;
|
||||
|
||||
if len_words == 0x7f {
|
||||
// Extended length (3 more bytes needed)
|
||||
if src.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize;
|
||||
header_len = 4;
|
||||
} else {
|
||||
header_len = 1;
|
||||
}
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len_words.checked_mul(4).ok_or_else(|| {
|
||||
Error::new(ErrorKind::InvalidData, "frame length overflow")
|
||||
})?;
|
||||
|
||||
// Validate size
|
||||
if byte_len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", byte_len, max_size)
|
||||
));
|
||||
}
|
||||
|
||||
let total_len = header_len + byte_len;
|
||||
|
||||
if src.len() < total_len {
|
||||
// Reserve space for the rest of the frame
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Extract data
|
||||
let _ = src.split_to(header_len);
|
||||
let data = src.split_to(byte_len).freeze();
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
// Validate alignment
|
||||
if data.len() % 4 != 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
// Simple ACK: send reversed data without header
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
for byte in data.iter().rev() {
|
||||
dst.put_u8(*byte);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let len_words = data.len() / 4;
|
||||
|
||||
if len_words < 0x7f {
|
||||
// Short header
|
||||
dst.reserve(1 + data.len());
|
||||
let mut len_byte = len_words as u8;
|
||||
if frame.meta.quickack {
|
||||
len_byte |= 0x80;
|
||||
}
|
||||
dst.put_u8(len_byte);
|
||||
} else if len_words < (1 << 24) {
|
||||
// Extended header
|
||||
dst.reserve(4 + data.len());
|
||||
let mut first = 0x7fu8;
|
||||
if frame.meta.quickack {
|
||||
first |= 0x80;
|
||||
}
|
||||
dst.put_u8(first);
|
||||
let len_bytes = (len_words as u32).to_le_bytes();
|
||||
dst.extend_from_slice(&len_bytes[..3]);
|
||||
} else {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", data.len())
|
||||
));
|
||||
}
|
||||
|
||||
dst.extend_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= Intermediate Protocol =============
|
||||
|
||||
fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
|
||||
if src.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", len, max_size)
|
||||
));
|
||||
}
|
||||
|
||||
let total_len = 4 + len;
|
||||
|
||||
if src.len() < total_len {
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Extract data
|
||||
let _ = src.split_to(4);
|
||||
let data = src.split_to(len).freeze();
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
// Simple ACK: just send data
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
dst.extend_from_slice(data);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
dst.reserve(4 + data.len());
|
||||
|
||||
let mut len = data.len() as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= Secure Intermediate Protocol =============
|
||||
|
||||
fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
|
||||
if src.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", len, max_size)
|
||||
));
|
||||
}
|
||||
|
||||
let total_len = 4 + len;
|
||||
|
||||
if src.len() < total_len {
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Calculate padding (indicated by length not divisible by 4)
|
||||
let padding_len = len % 4;
|
||||
let data_len = if padding_len != 0 {
|
||||
len - padding_len
|
||||
} else {
|
||||
len
|
||||
};
|
||||
|
||||
meta.padding_len = padding_len as u8;
|
||||
|
||||
// Extract data (excluding padding)
|
||||
let _ = src.split_to(4);
|
||||
let all_data = src.split_to(len);
|
||||
// Copy only the data portion, excluding padding
|
||||
let data = Bytes::copy_from_slice(&all_data[..data_len]);
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
|
||||
let data = &frame.data;
|
||||
|
||||
// Simple ACK: just send data
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
dst.extend_from_slice(data);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Generate padding to make length not divisible by 4
|
||||
let padding_len = if data.len() % 4 == 0 {
|
||||
// Add 1-3 bytes to make it non-aligned
|
||||
(SECURE_RANDOM.range(3) + 1) as usize
|
||||
} else {
|
||||
// Already non-aligned, can add 0-3
|
||||
SECURE_RANDOM.range(4) as usize
|
||||
};
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
dst.reserve(4 + total_len);
|
||||
|
||||
let mut len = total_len as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
if padding_len > 0 {
|
||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
||||
dst.extend_from_slice(&padding);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= Typed Codecs =============
|
||||
|
||||
/// Abridged protocol codec
|
||||
pub struct AbridgedCodec {
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl AbridgedCodec {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_frame_size: 16 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AbridgedCodec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for AbridgedCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_abridged(src, self.max_frame_size)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Frame> for AbridgedCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_abridged(&frame, dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl FrameCodecTrait for AbridgedCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Abridged
|
||||
}
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_abridged(frame, dst)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_abridged(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate protocol codec
|
||||
pub struct IntermediateCodec {
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl IntermediateCodec {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_frame_size: 16 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for IntermediateCodec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for IntermediateCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_intermediate(src, self.max_frame_size)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Frame> for IntermediateCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_intermediate(&frame, dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl FrameCodecTrait for IntermediateCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Intermediate
|
||||
}
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_intermediate(frame, dst)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_intermediate(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
4
|
||||
}
|
||||
}
|
||||
|
||||
/// Secure Intermediate protocol codec
|
||||
pub struct SecureCodec {
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl SecureCodec {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_frame_size: 16 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecureCodec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for SecureCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_secure(src, self.max_frame_size)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Frame> for SecureCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_secure(&frame, dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl FrameCodecTrait for SecureCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Secure
|
||||
}
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_secure(frame, dst)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_secure(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
4
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Tests =============
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
use tokio::io::duplex;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_abridged() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
|
||||
let mut reader = FramedRead::new(server, AbridgedCodec::new());
|
||||
|
||||
// Write a frame
|
||||
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
// Read it back
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_intermediate() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
let frame = Frame::new(Bytes::from_static(b"hello world"));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], b"hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_secure() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, SecureCodec::new());
|
||||
let mut reader = FramedRead::new(server, SecureCodec::new());
|
||||
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let frame = Frame::new(original.clone());
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], &original[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unified_codec() {
|
||||
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
|
||||
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
|
||||
|
||||
// Use 4-byte aligned data for abridged compatibility
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let frame = Frame::new(original.clone());
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(received.data.len(), 8);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_frames() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
// Send multiple frames
|
||||
for i in 0..10 {
|
||||
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
|
||||
let frame = Frame::new(Bytes::from(data));
|
||||
writer.send(frame).await.unwrap();
|
||||
}
|
||||
|
||||
// Receive them
|
||||
for i in 0..10 {
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(received.data.len(), (i + 1) * 10);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_quickack_flag() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert!(received.meta.quickack);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_too_large() {
|
||||
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
|
||||
.with_max_frame_size(100);
|
||||
|
||||
// Create a "frame" that claims to be very large
|
||||
let mut buf = BytesMut::new();
|
||||
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
|
||||
buf.extend_from_slice(&[0u8; 10]); // partial data
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,43 @@
|
||||
//! Stream wrappers for MTProto protocol layers
|
||||
|
||||
pub mod state;
|
||||
pub mod buffer_pool;
|
||||
pub mod traits;
|
||||
pub mod crypto_stream;
|
||||
pub mod tls_stream;
|
||||
pub mod frame;
|
||||
pub mod frame_codec;
|
||||
|
||||
// Legacy compatibility - will be removed later
|
||||
pub mod frame_stream;
|
||||
|
||||
// Re-export state machine types
|
||||
pub use state::{
|
||||
StreamState, Transition, PollResult,
|
||||
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
|
||||
};
|
||||
|
||||
// Re-export buffer pool
|
||||
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
|
||||
|
||||
// Re-export stream implementations
|
||||
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
||||
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||
pub use frame_stream::*;
|
||||
|
||||
// Re-export frame types
|
||||
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
|
||||
|
||||
// Re-export tokio-util compatible codecs
|
||||
pub use frame_codec::{
|
||||
FrameCodec,
|
||||
AbridgedCodec, IntermediateCodec, SecureCodec,
|
||||
};
|
||||
|
||||
// Legacy re-exports for compatibility
|
||||
pub use frame_stream::{
|
||||
AbridgedFrameReader, AbridgedFrameWriter,
|
||||
IntermediateFrameReader, IntermediateFrameWriter,
|
||||
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
|
||||
MtprotoFrameReader, MtprotoFrameWriter,
|
||||
FrameReaderKind, FrameWriterKind,
|
||||
};
|
||||
571
src/stream/state.rs
Normal file
571
src/stream/state.rs
Normal file
@@ -0,0 +1,571 @@
|
||||
//! State machine foundation types for async streams
|
||||
//!
|
||||
//! This module provides core types and traits for implementing
|
||||
//! stateful async streams with proper partial read/write handling.
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io;
|
||||
|
||||
// ============= Core Traits =============
|
||||
|
||||
/// Trait for stream states
|
||||
pub trait StreamState: Sized {
|
||||
/// Check if this is a terminal state (no more transitions possible)
|
||||
fn is_terminal(&self) -> bool;
|
||||
|
||||
/// Check if stream is in poisoned/error state
|
||||
fn is_poisoned(&self) -> bool;
|
||||
|
||||
/// Get human-readable state name for debugging
|
||||
fn state_name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
// ============= Transition Types =============
|
||||
|
||||
/// Result of a state transition
|
||||
#[derive(Debug)]
|
||||
pub enum Transition<S, O> {
|
||||
/// Stay in the same state, no output
|
||||
Same,
|
||||
/// Transition to a new state, no output
|
||||
Next(S),
|
||||
/// Complete with output, typically transitions to Idle
|
||||
Complete(O),
|
||||
/// Yield output and transition to new state
|
||||
Yield(O, S),
|
||||
/// Error occurred, transition to error state
|
||||
Error(io::Error),
|
||||
}
|
||||
|
||||
impl<S, O> Transition<S, O> {
|
||||
/// Check if transition produces output
|
||||
pub fn has_output(&self) -> bool {
|
||||
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
|
||||
}
|
||||
|
||||
/// Map the output value
|
||||
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
|
||||
match self {
|
||||
Transition::Same => Transition::Same,
|
||||
Transition::Next(s) => Transition::Next(s),
|
||||
Transition::Complete(o) => Transition::Complete(f(o)),
|
||||
Transition::Yield(o, s) => Transition::Yield(f(o), s),
|
||||
Transition::Error(e) => Transition::Error(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map the state value
|
||||
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
|
||||
match self {
|
||||
Transition::Same => Transition::Same,
|
||||
Transition::Next(s) => Transition::Next(f(s)),
|
||||
Transition::Complete(o) => Transition::Complete(o),
|
||||
Transition::Yield(o, s) => Transition::Yield(o, f(s)),
|
||||
Transition::Error(e) => Transition::Error(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Poll Result Types =============
|
||||
|
||||
/// Result of polling for more data
|
||||
#[derive(Debug)]
|
||||
pub enum PollResult<T> {
|
||||
/// Data is ready
|
||||
Ready(T),
|
||||
/// Operation would block, need to poll again
|
||||
Pending,
|
||||
/// Need more input data (minimum bytes required)
|
||||
NeedInput(usize),
|
||||
/// End of stream reached
|
||||
Eof,
|
||||
/// Error occurred
|
||||
Error(io::Error),
|
||||
}
|
||||
|
||||
impl<T> PollResult<T> {
|
||||
/// Check if result is ready
|
||||
pub fn is_ready(&self) -> bool {
|
||||
matches!(self, PollResult::Ready(_))
|
||||
}
|
||||
|
||||
/// Check if result indicates EOF
|
||||
pub fn is_eof(&self) -> bool {
|
||||
matches!(self, PollResult::Eof)
|
||||
}
|
||||
|
||||
/// Convert to Option, discarding non-ready states
|
||||
pub fn ok(self) -> Option<T> {
|
||||
match self {
|
||||
PollResult::Ready(t) => Some(t),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Map the value
|
||||
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
|
||||
match self {
|
||||
PollResult::Ready(t) => PollResult::Ready(f(t)),
|
||||
PollResult::Pending => PollResult::Pending,
|
||||
PollResult::NeedInput(n) => PollResult::NeedInput(n),
|
||||
PollResult::Eof => PollResult::Eof,
|
||||
PollResult::Error(e) => PollResult::Error(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<io::Result<T>> for PollResult<T> {
|
||||
fn from(result: io::Result<T>) -> Self {
|
||||
match result {
|
||||
Ok(t) => PollResult::Ready(t),
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof,
|
||||
Err(e) => PollResult::Error(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Buffer State =============
|
||||
|
||||
/// State for buffered reading operations
|
||||
#[derive(Debug)]
|
||||
pub struct ReadBuffer {
|
||||
/// The buffer holding data
|
||||
buffer: BytesMut,
|
||||
/// Target number of bytes to read (if known)
|
||||
target: Option<usize>,
|
||||
}
|
||||
|
||||
impl ReadBuffer {
|
||||
/// Create new empty read buffer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: BytesMut::with_capacity(8192),
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific capacity
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
buffer: BytesMut::with_capacity(capacity),
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with target size
|
||||
pub fn with_target(target: usize) -> Self {
|
||||
Self {
|
||||
buffer: BytesMut::with_capacity(target),
|
||||
target: Some(target),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current buffer length
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Check if target is reached
|
||||
pub fn is_complete(&self) -> bool {
|
||||
match self.target {
|
||||
Some(t) => self.buffer.len() >= t,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get remaining bytes needed
|
||||
pub fn remaining(&self) -> usize {
|
||||
match self.target {
|
||||
Some(t) => t.saturating_sub(self.buffer.len()),
|
||||
None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Append data to buffer
|
||||
pub fn extend(&mut self, data: &[u8]) {
|
||||
self.buffer.extend_from_slice(data);
|
||||
}
|
||||
|
||||
/// Take all data from buffer
|
||||
pub fn take(&mut self) -> Bytes {
|
||||
self.target = None;
|
||||
self.buffer.split().freeze()
|
||||
}
|
||||
|
||||
/// Take exactly n bytes
|
||||
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
|
||||
if self.buffer.len() >= n {
|
||||
Some(self.buffer.split_to(n).freeze())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a slice of the buffer
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
/// Get mutable access to underlying BytesMut
|
||||
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
|
||||
&mut self.buffer
|
||||
}
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
self.target = None;
|
||||
}
|
||||
|
||||
/// Set new target
|
||||
pub fn set_target(&mut self, target: usize) {
|
||||
self.target = Some(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ReadBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// State for buffered writing operations
|
||||
#[derive(Debug)]
|
||||
pub struct WriteBuffer {
|
||||
/// The buffer holding data to write
|
||||
buffer: BytesMut,
|
||||
/// Position of next byte to write
|
||||
position: usize,
|
||||
/// Maximum buffer size
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
impl WriteBuffer {
|
||||
/// Create new write buffer with default max size (256KB)
|
||||
pub fn new() -> Self {
|
||||
Self::with_max_size(256 * 1024)
|
||||
}
|
||||
|
||||
/// Create with specific max size
|
||||
pub fn with_max_size(max_size: usize) -> Self {
|
||||
Self {
|
||||
buffer: BytesMut::with_capacity(8192),
|
||||
position: 0,
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pending bytes count
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len() - self.position
|
||||
}
|
||||
|
||||
/// Check if buffer is empty (all written)
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.position >= self.buffer.len()
|
||||
}
|
||||
|
||||
/// Check if buffer is full
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.buffer.len() >= self.max_size
|
||||
}
|
||||
|
||||
/// Get remaining capacity
|
||||
pub fn remaining_capacity(&self) -> usize {
|
||||
self.max_size.saturating_sub(self.buffer.len())
|
||||
}
|
||||
|
||||
/// Append data to buffer
|
||||
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
|
||||
if self.buffer.len() + data.len() > self.max_size {
|
||||
return Err(());
|
||||
}
|
||||
self.buffer.extend_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get slice of data to write
|
||||
pub fn pending(&self) -> &[u8] {
|
||||
&self.buffer[self.position..]
|
||||
}
|
||||
|
||||
/// Advance position by n bytes (after successful write)
|
||||
pub fn advance(&mut self, n: usize) {
|
||||
self.position += n;
|
||||
|
||||
// If all data written, reset buffer
|
||||
if self.position >= self.buffer.len() {
|
||||
self.buffer.clear();
|
||||
self.position = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
self.position = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WriteBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Fixed-Size Buffer States =============
|
||||
|
||||
/// State for reading a fixed-size header
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HeaderBuffer<const N: usize> {
|
||||
/// The buffer
|
||||
data: [u8; N],
|
||||
/// Bytes filled so far
|
||||
filled: usize,
|
||||
}
|
||||
|
||||
impl<const N: usize> HeaderBuffer<N> {
|
||||
/// Create new empty header buffer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
data: [0u8; N],
|
||||
filled: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get slice for reading into
|
||||
pub fn unfilled_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.data[self.filled..]
|
||||
}
|
||||
|
||||
/// Advance filled count
|
||||
pub fn advance(&mut self, n: usize) {
|
||||
self.filled = (self.filled + n).min(N);
|
||||
}
|
||||
|
||||
/// Check if completely filled
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.filled >= N
|
||||
}
|
||||
|
||||
/// Get remaining bytes needed
|
||||
pub fn remaining(&self) -> usize {
|
||||
N - self.filled
|
||||
}
|
||||
|
||||
/// Get filled bytes as slice
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.data[..self.filled]
|
||||
}
|
||||
|
||||
/// Get complete buffer (panics if not complete)
|
||||
pub fn as_array(&self) -> &[u8; N] {
|
||||
assert!(self.is_complete());
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Take the buffer, resetting state
|
||||
pub fn take(&mut self) -> [u8; N] {
|
||||
let data = self.data;
|
||||
self.data = [0u8; N];
|
||||
self.filled = 0;
|
||||
data
|
||||
}
|
||||
|
||||
/// Reset to empty state
|
||||
pub fn reset(&mut self) {
|
||||
self.filled = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Default for HeaderBuffer<N> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Yield Buffer =============
|
||||
|
||||
/// Buffer for yielding data to caller in chunks
|
||||
#[derive(Debug)]
|
||||
pub struct YieldBuffer {
|
||||
data: Bytes,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl YieldBuffer {
|
||||
/// Create new yield buffer
|
||||
pub fn new(data: Bytes) -> Self {
|
||||
Self { data, position: 0 }
|
||||
}
|
||||
|
||||
/// Check if all data has been yielded
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.position >= self.data.len()
|
||||
}
|
||||
|
||||
/// Get remaining bytes
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.data.len() - self.position
|
||||
}
|
||||
|
||||
/// Copy data to output slice, return bytes copied
|
||||
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
|
||||
let available = &self.data[self.position..];
|
||||
let to_copy = available.len().min(dst.len());
|
||||
dst[..to_copy].copy_from_slice(&available[..to_copy]);
|
||||
self.position += to_copy;
|
||||
to_copy
|
||||
}
|
||||
|
||||
/// Get remaining data as slice
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.data[self.position..]
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Macros =============
|
||||
|
||||
/// Macro to simplify state transitions in poll methods
|
||||
#[macro_export]
|
||||
macro_rules! transition {
|
||||
(same) => {
|
||||
$crate::stream::state::Transition::Same
|
||||
};
|
||||
(next $state:expr) => {
|
||||
$crate::stream::state::Transition::Next($state)
|
||||
};
|
||||
(complete $output:expr) => {
|
||||
$crate::stream::state::Transition::Complete($output)
|
||||
};
|
||||
(yield $output:expr, $state:expr) => {
|
||||
$crate::stream::state::Transition::Yield($output, $state)
|
||||
};
|
||||
(error $err:expr) => {
|
||||
$crate::stream::state::Transition::Error($err)
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro to match poll ready or return pending
|
||||
#[macro_export]
|
||||
macro_rules! ready_or_pending {
|
||||
($poll:expr) => {
|
||||
match $poll {
|
||||
std::task::Poll::Ready(t) => t,
|
||||
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_read_buffer_basic() {
|
||||
let mut buf = ReadBuffer::with_target(10);
|
||||
assert_eq!(buf.remaining(), 10);
|
||||
assert!(!buf.is_complete());
|
||||
|
||||
buf.extend(b"hello");
|
||||
assert_eq!(buf.len(), 5);
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
assert!(!buf.is_complete());
|
||||
|
||||
buf.extend(b"world");
|
||||
assert_eq!(buf.len(), 10);
|
||||
assert!(buf.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_buffer_take() {
|
||||
let mut buf = ReadBuffer::new();
|
||||
buf.extend(b"test data");
|
||||
|
||||
let data = buf.take();
|
||||
assert_eq!(&data[..], b"test data");
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_buffer_basic() {
|
||||
let mut buf = WriteBuffer::with_max_size(100);
|
||||
assert!(buf.is_empty());
|
||||
|
||||
buf.extend(b"hello").unwrap();
|
||||
assert_eq!(buf.len(), 5);
|
||||
assert!(!buf.is_empty());
|
||||
|
||||
buf.advance(3);
|
||||
assert_eq!(buf.len(), 2);
|
||||
assert_eq!(buf.pending(), b"lo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_buffer_overflow() {
|
||||
let mut buf = WriteBuffer::with_max_size(10);
|
||||
assert!(buf.extend(b"short").is_ok());
|
||||
assert!(buf.extend(b"toolong").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_header_buffer() {
|
||||
let mut buf = HeaderBuffer::<5>::new();
|
||||
assert!(!buf.is_complete());
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
|
||||
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
|
||||
buf.advance(3);
|
||||
assert_eq!(buf.remaining(), 2);
|
||||
|
||||
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
|
||||
buf.advance(2);
|
||||
assert!(buf.is_complete());
|
||||
assert_eq!(buf.as_array(), b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yield_buffer() {
|
||||
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
|
||||
|
||||
let mut dst = [0u8; 5];
|
||||
assert_eq!(buf.copy_to(&mut dst), 5);
|
||||
assert_eq!(&dst, b"hello");
|
||||
|
||||
assert_eq!(buf.remaining(), 6);
|
||||
|
||||
let mut dst = [0u8; 10];
|
||||
assert_eq!(buf.copy_to(&mut dst), 6);
|
||||
assert_eq!(&dst[..6], b" world");
|
||||
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transition_map() {
|
||||
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
|
||||
let t = t.map_output(|s| s.len());
|
||||
|
||||
match t {
|
||||
Transition::Complete(5) => {}
|
||||
_ => panic!("Expected Complete(5)"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_poll_result() {
|
||||
let r: PollResult<i32> = PollResult::Ready(42);
|
||||
assert!(r.is_ready());
|
||||
assert_eq!(r.ok(), Some(42));
|
||||
|
||||
let r: PollResult<i32> = PollResult::Eof;
|
||||
assert!(r.is_eof());
|
||||
assert_eq!(r.ok(), None);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,11 @@
|
||||
pub mod pool;
|
||||
pub mod proxy_protocol;
|
||||
pub mod socket;
|
||||
pub mod socks;
|
||||
pub mod upstream;
|
||||
|
||||
pub use pool::ConnectionPool;
|
||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||
pub use socket::*;
|
||||
pub use socks::*;
|
||||
pub use upstream::UpstreamManager;
|
||||
@@ -1,7 +1,7 @@
|
||||
//! TCP Socket Configuration
|
||||
|
||||
use std::io::Result;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
|
||||
@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
}
|
||||
|
||||
// Set buffer sizes
|
||||
set_buffer_sizes(&socket, 65536, 65536)?;
|
||||
// CHANGED: Removed manual buffer size setting (was 256KB).
|
||||
// Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
|
||||
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set socket buffer sizes
|
||||
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
|
||||
// These may fail on some systems, so we ignore errors
|
||||
let _ = socket.set_recv_buffer_size(recv);
|
||||
let _ = socket.set_send_buffer_size(send);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configure socket for accepting client connections
|
||||
pub fn configure_client_socket(
|
||||
stream: &TcpStream,
|
||||
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
|
||||
// Set TCP user timeout (Linux only)
|
||||
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
|
||||
// is implemented in relay_bidirectional instead
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
@@ -93,6 +88,11 @@ pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
|
||||
|
||||
/// Create a new TCP socket for outgoing connections
|
||||
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||
create_outgoing_socket_bound(addr, None)
|
||||
}
|
||||
|
||||
/// Create a new TCP socket for outgoing connections, optionally bound to a specific interface
|
||||
pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option<IpAddr>) -> Result<Socket> {
|
||||
let domain = if addr.is_ipv4() {
|
||||
Domain::IPV4
|
||||
} else {
|
||||
@@ -107,9 +107,16 @@ pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||
// Disable Nagle
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
if let Some(bind_ip) = bind_addr {
|
||||
let bind_sock_addr = SocketAddr::new(bind_ip, 0);
|
||||
socket.bind(&bind_sock_addr.into())?;
|
||||
debug!("Bound outgoing socket to {}", bind_ip);
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
|
||||
/// Get local address of a socket
|
||||
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||
stream.local_addr().ok()
|
||||
|
||||
145
src/transport/socks.rs
Normal file
145
src/transport/socks.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
//! SOCKS4/5 Client Implementation
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
pub async fn connect_socks4(
|
||||
stream: &mut TcpStream,
|
||||
target: SocketAddr,
|
||||
user_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let ip = match target.ip() {
|
||||
IpAddr::V4(ip) => ip,
|
||||
IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())),
|
||||
};
|
||||
|
||||
let port = target.port();
|
||||
let user = user_id.unwrap_or("").as_bytes();
|
||||
|
||||
// VN (4) | CD (1) | DSTPORT (2) | DSTIP (4) | USERID (variable) | NULL (1)
|
||||
let mut buf = Vec::with_capacity(9 + user.len());
|
||||
buf.push(4); // VN
|
||||
buf.push(1); // CD (CONNECT)
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf.extend_from_slice(&ip.octets());
|
||||
buf.extend_from_slice(user);
|
||||
buf.push(0); // NULL
|
||||
|
||||
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
// Response: VN (1) | CD (1) | DSTPORT (2) | DSTIP (4)
|
||||
let mut resp = [0u8; 8];
|
||||
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if resp[1] != 90 {
|
||||
return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1])));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect_socks5(
|
||||
stream: &mut TcpStream,
|
||||
target: SocketAddr,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
) -> Result<()> {
|
||||
// 1. Auth negotiation
|
||||
// VER (1) | NMETHODS (1) | METHODS (variable)
|
||||
let mut methods = vec![0u8]; // No auth
|
||||
if username.is_some() {
|
||||
methods.push(2u8); // Username/Password
|
||||
}
|
||||
|
||||
let mut buf = vec![5u8, methods.len() as u8];
|
||||
buf.extend_from_slice(&methods);
|
||||
|
||||
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if resp[0] != 5 {
|
||||
return Err(ProxyError::Proxy("Invalid SOCKS5 version".to_string()));
|
||||
}
|
||||
|
||||
match resp[1] {
|
||||
0 => {}, // No auth
|
||||
2 => {
|
||||
// Username/Password auth
|
||||
if let (Some(u), Some(p)) = (username, password) {
|
||||
let u_bytes = u.as_bytes();
|
||||
let p_bytes = p.as_bytes();
|
||||
|
||||
let mut auth_buf = Vec::with_capacity(3 + u_bytes.len() + p_bytes.len());
|
||||
auth_buf.push(1); // VER
|
||||
auth_buf.push(u_bytes.len() as u8);
|
||||
auth_buf.extend_from_slice(u_bytes);
|
||||
auth_buf.push(p_bytes.len() as u8);
|
||||
auth_buf.extend_from_slice(p_bytes);
|
||||
|
||||
stream.write_all(&auth_buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
let mut auth_resp = [0u8; 2];
|
||||
stream.read_exact(&mut auth_resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if auth_resp[1] != 0 {
|
||||
return Err(ProxyError::Proxy("SOCKS5 authentication failed".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(ProxyError::Proxy("SOCKS5 server requires authentication".to_string()));
|
||||
}
|
||||
},
|
||||
_ => return Err(ProxyError::Proxy("Unsupported SOCKS5 auth method".to_string())),
|
||||
}
|
||||
|
||||
// 2. Connection request
|
||||
// VER (1) | CMD (1) | RSV (1) | ATYP (1) | DST.ADDR (variable) | DST.PORT (2)
|
||||
let mut req = vec![5u8, 1u8, 0u8]; // CONNECT
|
||||
|
||||
match target {
|
||||
SocketAddr::V4(v4) => {
|
||||
req.push(1u8); // IPv4
|
||||
req.extend_from_slice(&v4.ip().octets());
|
||||
},
|
||||
SocketAddr::V6(v6) => {
|
||||
req.push(4u8); // IPv6
|
||||
req.extend_from_slice(&v6.ip().octets());
|
||||
},
|
||||
}
|
||||
|
||||
req.extend_from_slice(&target.port().to_be_bytes());
|
||||
|
||||
stream.write_all(&req).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
// Response
|
||||
let mut head = [0u8; 4];
|
||||
stream.read_exact(&mut head).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if head[1] != 0 {
|
||||
return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1])));
|
||||
}
|
||||
|
||||
// Skip address part of response
|
||||
match head[3] {
|
||||
1 => { // IPv4
|
||||
let mut addr = [0u8; 4 + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
3 => { // Domain
|
||||
let mut len = [0u8; 1];
|
||||
stream.read_exact(&mut len).await.map_err(|e| ProxyError::Io(e))?;
|
||||
let mut addr = vec![0u8; len[0] as usize + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
4 => { // IPv6
|
||||
let mut addr = [0u8; 16 + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
_ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
259
src/transport/upstream.rs
Normal file
259
src/transport/upstream.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
//! Upstream Management
|
||||
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
use rand::Rng;
|
||||
use tracing::{debug, warn, error, info};
|
||||
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::error::{Result, ProxyError};
|
||||
use crate::transport::socket::create_outgoing_socket_bound;
|
||||
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UpstreamState {
|
||||
config: UpstreamConfig,
|
||||
healthy: bool,
|
||||
fails: u32,
|
||||
last_check: std::time::Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UpstreamManager {
|
||||
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
||||
}
|
||||
|
||||
impl UpstreamManager {
|
||||
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
|
||||
let states = configs.into_iter()
|
||||
.filter(|c| c.enabled)
|
||||
.map(|c| UpstreamState {
|
||||
config: c,
|
||||
healthy: true, // Optimistic start
|
||||
fails: 0,
|
||||
last_check: std::time::Instant::now(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
upstreams: Arc::new(RwLock::new(states)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select an upstream using Weighted Round Robin (simplified)
|
||||
async fn select_upstream(&self) -> Option<usize> {
|
||||
let upstreams = self.upstreams.read().await;
|
||||
if upstreams.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let healthy_indices: Vec<usize> = upstreams.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, u)| u.healthy)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
// If all unhealthy, try any random one
|
||||
return Some(rand::thread_rng().gen_range(0..upstreams.len()));
|
||||
}
|
||||
|
||||
// Weighted selection
|
||||
let total_weight: u32 = healthy_indices.iter()
|
||||
.map(|&i| upstreams[i].config.weight as u32)
|
||||
.sum();
|
||||
|
||||
if total_weight == 0 {
|
||||
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
|
||||
}
|
||||
|
||||
let mut choice = rand::thread_rng().gen_range(0..total_weight);
|
||||
|
||||
for &idx in &healthy_indices {
|
||||
let weight = upstreams[idx].config.weight as u32;
|
||||
if choice < weight {
|
||||
return Some(idx);
|
||||
}
|
||||
choice -= weight;
|
||||
}
|
||||
|
||||
Some(healthy_indices[0])
|
||||
}
|
||||
|
||||
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> {
|
||||
let idx = self.select_upstream().await
|
||||
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
|
||||
|
||||
let upstream = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard[idx].config.clone()
|
||||
};
|
||||
|
||||
match self.connect_via_upstream(&upstream, target).await {
|
||||
Ok(stream) => {
|
||||
// Mark success
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
}
|
||||
Ok(stream)
|
||||
},
|
||||
Err(e) => {
|
||||
// Mark failure
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
u.fails += 1;
|
||||
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
|
||||
if u.fails > 3 {
|
||||
u.healthy = false;
|
||||
warn!("Upstream disabled due to failures: {:?}", u.config);
|
||||
}
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
|
||||
match &config.upstream_type {
|
||||
UpstreamType::Direct { interface } => {
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&target.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks4 { address, interface, user_id } => {
|
||||
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&proxy_addr.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let mut stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks5 { address, interface, username, password } => {
|
||||
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&proxy_addr.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let mut stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
|
||||
Ok(stream)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Background task to check health
|
||||
pub async fn run_health_checks(&self) {
|
||||
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2)
|
||||
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
|
||||
let count = self.upstreams.read().await.len();
|
||||
for i in 0..count {
|
||||
let config = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard[i].config.clone()
|
||||
};
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.connect_via_upstream(&config, check_target)
|
||||
).await;
|
||||
|
||||
let mut guard = self.upstreams.write().await;
|
||||
let u = &mut guard[i];
|
||||
|
||||
match result {
|
||||
Ok(Ok(_stream)) => {
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Health check failed for {:?}: {}", u.config, e);
|
||||
// Don't mark unhealthy immediately in background check
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Health check timeout for {:?}", u.config);
|
||||
}
|
||||
}
|
||||
u.last_check = std::time::Instant::now();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//! IP Addr Detect
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, SocketAddr, UdpSocket};
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
@@ -40,28 +40,74 @@ const IPV6_URLS: &[&str] = &[
|
||||
"http://api6.ipify.org/",
|
||||
];
|
||||
|
||||
/// Detect local IP address by connecting to a public DNS
|
||||
/// This does not actually send any packets
|
||||
fn get_local_ip(target: &str) -> Option<IpAddr> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
|
||||
socket.connect(target).ok()?;
|
||||
socket.local_addr().ok().map(|addr| addr.ip())
|
||||
}
|
||||
|
||||
fn get_local_ipv6(target: &str) -> Option<IpAddr> {
|
||||
let socket = UdpSocket::bind("[::]:0").ok()?;
|
||||
socket.connect(target).ok()?;
|
||||
socket.local_addr().ok().map(|addr| addr.ip())
|
||||
}
|
||||
|
||||
/// Detect public IP addresses
|
||||
pub async fn detect_ip() -> IpInfo {
|
||||
let mut info = IpInfo::default();
|
||||
|
||||
// Detect IPv4
|
||||
for url in IPV4_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv4() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv4 address");
|
||||
break;
|
||||
// Try to get local interface IP first (default gateway interface)
|
||||
// We connect to Google DNS to find out which interface is used for routing
|
||||
if let Some(ip) = get_local_ip("8.8.8.8:80") {
|
||||
if ip.is_ipv4() && !ip.is_loopback() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected local IPv4 address via routing");
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ip) = get_local_ipv6("[2001:4860:4860::8888]:80") {
|
||||
if ip.is_ipv6() && !ip.is_loopback() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected local IPv6 address via routing");
|
||||
}
|
||||
}
|
||||
|
||||
// If local detection failed or returned private IP (and we want public),
|
||||
// or just as a fallback/verification, we might want to check external services.
|
||||
// However, the requirement is: "if IP for listening is not set... it should be IP from interface...
|
||||
// if impossible - request external resources".
|
||||
|
||||
// So if we found a local IP, we might be good. But often servers are behind NAT.
|
||||
// If the local IP is private, we probably want the public IP for the tg:// link.
|
||||
// Let's check if the detected IPs are private.
|
||||
|
||||
let need_external_v4 = info.ipv4.map_or(true, |ip| is_private_ip(ip));
|
||||
let need_external_v6 = info.ipv6.map_or(true, |ip| is_private_ip(ip));
|
||||
|
||||
if need_external_v4 {
|
||||
debug!("Local IPv4 is private or missing, checking external services...");
|
||||
for url in IPV4_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv4() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected public IPv4 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect IPv6
|
||||
for url in IPV6_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv6() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv6 address");
|
||||
break;
|
||||
if need_external_v6 {
|
||||
debug!("Local IPv6 is private or missing, checking external services...");
|
||||
for url in IPV6_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv6() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected public IPv6 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -73,6 +119,17 @@ pub async fn detect_ip() -> IpInfo {
|
||||
info
|
||||
}
|
||||
|
||||
fn is_private_ip(ip: IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local()
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
ipv6.is_loopback() || (ipv6.segments()[0] & 0xfe00) == 0xfc00 // Unique Local
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch IP from URL
|
||||
async fn fetch_ip(url: &str) -> Option<IpAddr> {
|
||||
let client = reqwest::Client::builder()
|
||||
|
||||
12
telemt.service
Normal file
12
telemt.service
Normal file
@@ -0,0 +1,12 @@
|
||||
[Unit]
|
||||
Description=Telemt
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=/bin
|
||||
ExecStart=/bin/telemt /etc/telemt.toml
|
||||
Restart=on-failure
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
Reference in New Issue
Block a user