Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f8cde8317 | ||
|
|
e32d8e6c7d | ||
|
|
d405756b94 | ||
|
|
a8c3128c50 | ||
|
|
70859aa5cf | ||
|
|
9b850b0bfb | ||
|
|
de28655dd2 | ||
|
|
e62b41ae64 | ||
|
|
f1c1f42de8 | ||
|
|
a494dfa9eb | ||
|
|
e6bf7ac40e | ||
|
|
889a5fa19b | ||
|
|
d8ff958481 | ||
|
|
28ee74787b | ||
|
|
a688bfe22f | ||
|
|
91eea914b3 | ||
|
|
3ba97a08fa | ||
|
|
6e445be108 | ||
|
|
3c6752644a | ||
|
|
9bd12f6acb | ||
|
|
61581203c4 | ||
|
|
84668e671e | ||
|
|
5bde202866 | ||
|
|
9304d5256a | ||
|
|
364bc6e278 | ||
|
|
e83db704b7 | ||
|
|
acf90043eb | ||
|
|
0011e20653 | ||
|
|
41fb307858 | ||
|
|
6a78c44d2e | ||
|
|
be9c9858ac | ||
|
|
2fa8d85b4c | ||
|
|
310666fd44 | ||
|
|
6cafee153a | ||
|
|
32f60f34db | ||
|
|
158eae8d2a | ||
|
|
92cedabc81 | ||
|
|
b9428d9780 | ||
|
|
5876f0c4d5 |
45
.github/workflows/codeql.yml
vendored
Normal file
45
.github/workflows/codeql.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
name: "CodeQL Advanced"
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * 0'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze (${{ matrix.language }})
|
||||||
|
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
security-events: write
|
||||||
|
packages: read
|
||||||
|
actions: read
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- language: actions
|
||||||
|
build-mode: none
|
||||||
|
- language: rust
|
||||||
|
build-mode: none
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Initialize CodeQL
|
||||||
|
uses: github/codeql-action/init@v4
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
build-mode: ${{ matrix.build-mode }}
|
||||||
|
config-file: .github/codeql/codeql-config.yml
|
||||||
|
|
||||||
|
- name: Perform CodeQL Analysis
|
||||||
|
uses: github/codeql-action/analyze@v4
|
||||||
|
with:
|
||||||
|
category: "/language:${{ matrix.language }}"
|
||||||
20
.github/workflows/queries/common/ProductionOnly.qll
vendored
Normal file
20
.github/workflows/queries/common/ProductionOnly.qll
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import rust
|
||||||
|
|
||||||
|
predicate isTestOnly(Item i) {
|
||||||
|
exists(ConditionalCompilation cc |
|
||||||
|
cc.getItem() = i and
|
||||||
|
cc.getCfg().toString() = "test"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
predicate hasTestAttribute(Item i) {
|
||||||
|
exists(Attribute a |
|
||||||
|
a.getItem() = i and
|
||||||
|
a.getName() = "test"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
predicate isProductionCode(Item i) {
|
||||||
|
not isTestOnly(i) and
|
||||||
|
not hasTestAttribute(i)
|
||||||
|
}
|
||||||
4
.github/workflows/queries/qlpack.yml
vendored
Normal file
4
.github/workflows/queries/qlpack.yml
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
name: rust-production-only
|
||||||
|
version: 0.0.1
|
||||||
|
dependencies:
|
||||||
|
codeql/rust-all: "*"
|
||||||
5
.github/workflows/rust.yml
vendored
5
.github/workflows/rust.yml
vendored
@@ -14,6 +14,11 @@ jobs:
|
|||||||
name: Build
|
name: Build
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
actions: write
|
||||||
|
checks: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
2742
Cargo.lock
generated
Normal file
2742
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
27
Cargo.toml
27
Cargo.toml
@@ -1,15 +1,14 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "1.0.0"
|
version = "1.2.0"
|
||||||
edition = "2021"
|
edition = "2024"
|
||||||
rust-version = "1.75"
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# C
|
# C
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
|
|
||||||
# Async runtime
|
# Async runtime
|
||||||
tokio = { version = "1.35", features = ["full", "tracing"] }
|
tokio = { version = "1.42", features = ["full", "tracing"] }
|
||||||
tokio-util = { version = "0.7", features = ["codec"] }
|
tokio-util = { version = "0.7", features = ["codec"] }
|
||||||
|
|
||||||
# Crypto
|
# Crypto
|
||||||
@@ -20,41 +19,41 @@ sha2 = "0.10"
|
|||||||
sha1 = "0.10"
|
sha1 = "0.10"
|
||||||
md-5 = "0.10"
|
md-5 = "0.10"
|
||||||
hmac = "0.12"
|
hmac = "0.12"
|
||||||
crc32fast = "1.3"
|
crc32fast = "1.4"
|
||||||
|
zeroize = { version = "1.8", features = ["derive"] }
|
||||||
|
|
||||||
# Network
|
# Network
|
||||||
socket2 = { version = "0.5", features = ["all"] }
|
socket2 = { version = "0.5", features = ["all"] }
|
||||||
rustls = "0.22"
|
|
||||||
|
|
||||||
# Serial
|
# Serialization
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
bytes = "1.5"
|
bytes = "1.9"
|
||||||
thiserror = "1.0"
|
thiserror = "2.0"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
dashmap = "5.5"
|
dashmap = "5.5"
|
||||||
lru = "0.12"
|
lru = "0.12"
|
||||||
rand = "0.8"
|
rand = "0.9"
|
||||||
chrono = { version = "0.4", features = ["serde"] }
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
base64 = "0.21"
|
base64 = "0.22"
|
||||||
url = "2.5"
|
url = "2.5"
|
||||||
regex = "1.10"
|
regex = "1.11"
|
||||||
once_cell = "1.19"
|
|
||||||
crossbeam-queue = "0.3"
|
crossbeam-queue = "0.3"
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio-test = "0.4"
|
tokio-test = "0.4"
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
proptest = "1.4"
|
proptest = "1.4"
|
||||||
|
futures = "0.3"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "crypto_bench"
|
name = "crypto_bench"
|
||||||
|
|||||||
21
README.md
21
README.md
@@ -2,6 +2,26 @@
|
|||||||
|
|
||||||
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
|
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
|
||||||
|
|
||||||
|
## Emergency
|
||||||
|
**Важное сообщение для пользователей из России**
|
||||||
|
|
||||||
|
Мы работаем над проектом с Нового года и сейчас готовим новый релиз - 1.2
|
||||||
|
|
||||||
|
В нём имплементируется поддержка Middle Proxy Protocol - основного терминатора для Ad Tag:
|
||||||
|
работа над ним идёт с 6 ферваля, а уже 10 февраля произошли "громкие события"...
|
||||||
|
|
||||||
|
Если у вас есть компетенции в асинхронных сетевых приложениях - мы открыты к предложениям и pull requests
|
||||||
|
|
||||||
|
**Important message for users from Russia**
|
||||||
|
|
||||||
|
We've been working on the project since December 30 and are currently preparing a new release – 1.2
|
||||||
|
|
||||||
|
It implements support for the Middle Proxy Protocol – the primary point for the Ad Tag:
|
||||||
|
development on it started on February 6th, and by February 10th, "big activity" in Russia had already "taken place"...
|
||||||
|
|
||||||
|
If you have expertise in asynchronous network applications – we are open to ideas and pull requests!
|
||||||
|
|
||||||
|
# Features
|
||||||
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
|
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
|
||||||
|
|
||||||
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
|
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
|
||||||
@@ -168,6 +188,7 @@ tls_domain = "petrovich.ru"
|
|||||||
mask = true
|
mask = true
|
||||||
mask_port = 443
|
mask_port = 443
|
||||||
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
||||||
|
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
|
||||||
fake_cert_len = 2048
|
fake_cert_len = 2048
|
||||||
|
|
||||||
# === Access Control & Users ===
|
# === Access Control & Users ===
|
||||||
|
|||||||
18
config.toml
18
config.toml
@@ -6,8 +6,13 @@ show_link = ["hello"]
|
|||||||
[general]
|
[general]
|
||||||
prefer_ipv6 = false
|
prefer_ipv6 = false
|
||||||
fast_mode = true
|
fast_mode = true
|
||||||
use_middle_proxy = false
|
use_middle_proxy = true
|
||||||
# ad_tag = "..."
|
ad_tag = "00000000000000000000000000000000"
|
||||||
|
|
||||||
|
# Log level: debug | verbose | normal | silent
|
||||||
|
# Can be overridden with --silent or --log-level CLI flags
|
||||||
|
# RUST_LOG env var takes absolute priority over all of these
|
||||||
|
log_level = "normal"
|
||||||
|
|
||||||
[general.modes]
|
[general.modes]
|
||||||
classic = false
|
classic = false
|
||||||
@@ -43,12 +48,13 @@ tls_domain = "petrovich.ru"
|
|||||||
mask = true
|
mask = true
|
||||||
mask_port = 443
|
mask_port = 443
|
||||||
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
||||||
|
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
|
||||||
fake_cert_len = 2048
|
fake_cert_len = 2048
|
||||||
|
|
||||||
# === Access Control & Users ===
|
# === Access Control & Users ===
|
||||||
# username "hello" is used for example
|
|
||||||
[access]
|
[access]
|
||||||
replay_check_len = 65536
|
replay_check_len = 65536
|
||||||
|
replay_window_secs = 1800
|
||||||
ignore_time_skew = false
|
ignore_time_skew = false
|
||||||
|
|
||||||
[access.users]
|
[access.users]
|
||||||
@@ -62,17 +68,13 @@ hello = "00000000000000000000000000000000"
|
|||||||
# hello = 1073741824 # 1 GB
|
# hello = 1073741824 # 1 GB
|
||||||
|
|
||||||
# === Upstreams & Routing ===
|
# === Upstreams & Routing ===
|
||||||
# By default, direct connection is used, but you can add SOCKS proxy
|
|
||||||
|
|
||||||
# Direct - Default
|
|
||||||
[[upstreams]]
|
[[upstreams]]
|
||||||
type = "direct"
|
type = "direct"
|
||||||
enabled = true
|
enabled = true
|
||||||
weight = 10
|
weight = 10
|
||||||
|
|
||||||
# SOCKS5
|
|
||||||
# [[upstreams]]
|
# [[upstreams]]
|
||||||
# type = "socks5"
|
# type = "socks5"
|
||||||
# address = "127.0.0.1:9050"
|
# address = "127.0.0.1:1080"
|
||||||
# enabled = false
|
# enabled = false
|
||||||
# weight = 1
|
# weight = 1
|
||||||
300
src/cli.rs
Normal file
300
src/cli.rs
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
//! CLI commands: --init (fire-and-forget setup)
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Command;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
/// Options for the init command
|
||||||
|
pub struct InitOptions {
|
||||||
|
pub port: u16,
|
||||||
|
pub domain: String,
|
||||||
|
pub secret: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub config_dir: PathBuf,
|
||||||
|
pub no_start: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for InitOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
port: 443,
|
||||||
|
domain: "www.google.com".to_string(),
|
||||||
|
secret: None,
|
||||||
|
username: "user".to_string(),
|
||||||
|
config_dir: PathBuf::from("/etc/telemt"),
|
||||||
|
no_start: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse --init subcommand options from CLI args.
|
||||||
|
///
|
||||||
|
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
|
||||||
|
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
|
||||||
|
if !args.iter().any(|a| a == "--init") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut opts = InitOptions::default();
|
||||||
|
let mut i = 0;
|
||||||
|
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--port" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.port = args[i].parse().unwrap_or(443);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--domain" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.domain = args[i].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--secret" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.secret = Some(args[i].clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--user" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.username = args[i].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--config-dir" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.config_dir = PathBuf::from(&args[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--no-start" => {
|
||||||
|
opts.no_start = true;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the fire-and-forget setup.
|
||||||
|
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
eprintln!("[telemt] Fire-and-forget setup");
|
||||||
|
eprintln!();
|
||||||
|
|
||||||
|
// 1. Generate or validate secret
|
||||||
|
let secret = match opts.secret {
|
||||||
|
Some(s) => {
|
||||||
|
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||||
|
eprintln!("[error] Secret must be exactly 32 hex characters");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
s
|
||||||
|
}
|
||||||
|
None => generate_secret(),
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!("[+] Secret: {}", secret);
|
||||||
|
eprintln!("[+] User: {}", opts.username);
|
||||||
|
eprintln!("[+] Port: {}", opts.port);
|
||||||
|
eprintln!("[+] Domain: {}", opts.domain);
|
||||||
|
|
||||||
|
// 2. Create config directory
|
||||||
|
fs::create_dir_all(&opts.config_dir)?;
|
||||||
|
let config_path = opts.config_dir.join("config.toml");
|
||||||
|
|
||||||
|
// 3. Write config
|
||||||
|
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
fs::write(&config_path, &config_content)?;
|
||||||
|
eprintln!("[+] Config written to {}", config_path.display());
|
||||||
|
|
||||||
|
// 4. Write systemd unit
|
||||||
|
let exe_path = std::env::current_exe()
|
||||||
|
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
|
||||||
|
|
||||||
|
let unit_path = Path::new("/etc/systemd/system/telemt.service");
|
||||||
|
let unit_content = generate_systemd_unit(&exe_path, &config_path);
|
||||||
|
|
||||||
|
match fs::write(unit_path, &unit_content) {
|
||||||
|
Ok(()) => {
|
||||||
|
eprintln!("[+] Systemd unit written to {}", unit_path.display());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
|
||||||
|
eprintln!("[!] Manual unit file content:");
|
||||||
|
eprintln!("{}", unit_content);
|
||||||
|
|
||||||
|
// Still print links and config
|
||||||
|
print_links(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Reload systemd
|
||||||
|
run_cmd("systemctl", &["daemon-reload"]);
|
||||||
|
|
||||||
|
// 6. Enable service
|
||||||
|
run_cmd("systemctl", &["enable", "telemt.service"]);
|
||||||
|
eprintln!("[+] Service enabled");
|
||||||
|
|
||||||
|
// 7. Start service (unless --no-start)
|
||||||
|
if !opts.no_start {
|
||||||
|
run_cmd("systemctl", &["start", "telemt.service"]);
|
||||||
|
eprintln!("[+] Service started");
|
||||||
|
|
||||||
|
// Brief delay then check status
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
let status = Command::new("systemctl")
|
||||||
|
.args(["is-active", "telemt.service"])
|
||||||
|
.output();
|
||||||
|
|
||||||
|
match status {
|
||||||
|
Ok(out) if out.status.success() => {
|
||||||
|
eprintln!("[+] Service is running");
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("[!] Service may not have started correctly");
|
||||||
|
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eprintln!("[+] Service not started (--no-start)");
|
||||||
|
eprintln!("[+] Start manually: systemctl start telemt.service");
|
||||||
|
}
|
||||||
|
|
||||||
|
eprintln!();
|
||||||
|
|
||||||
|
// 8. Print links
|
||||||
|
print_links(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_secret() -> String {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
|
||||||
|
hex::encode(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
|
||||||
|
format!(
|
||||||
|
r#"# Telemt MTProxy — auto-generated config
|
||||||
|
# Re-run `telemt --init` to regenerate
|
||||||
|
|
||||||
|
show_link = ["{username}"]
|
||||||
|
|
||||||
|
[general]
|
||||||
|
prefer_ipv6 = false
|
||||||
|
fast_mode = true
|
||||||
|
use_middle_proxy = false
|
||||||
|
log_level = "normal"
|
||||||
|
|
||||||
|
[general.modes]
|
||||||
|
classic = false
|
||||||
|
secure = false
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
[server]
|
||||||
|
port = {port}
|
||||||
|
listen_addr_ipv4 = "0.0.0.0"
|
||||||
|
listen_addr_ipv6 = "::"
|
||||||
|
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "0.0.0.0"
|
||||||
|
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "::"
|
||||||
|
|
||||||
|
[timeouts]
|
||||||
|
client_handshake = 15
|
||||||
|
tg_connect = 10
|
||||||
|
client_keepalive = 60
|
||||||
|
client_ack = 300
|
||||||
|
|
||||||
|
[censorship]
|
||||||
|
tls_domain = "{domain}"
|
||||||
|
mask = true
|
||||||
|
mask_port = 443
|
||||||
|
fake_cert_len = 2048
|
||||||
|
|
||||||
|
[access]
|
||||||
|
replay_check_len = 65536
|
||||||
|
replay_window_secs = 1800
|
||||||
|
ignore_time_skew = false
|
||||||
|
|
||||||
|
[access.users]
|
||||||
|
{username} = "{secret}"
|
||||||
|
|
||||||
|
[[upstreams]]
|
||||||
|
type = "direct"
|
||||||
|
enabled = true
|
||||||
|
weight = 10
|
||||||
|
"#,
|
||||||
|
username = username,
|
||||||
|
secret = secret,
|
||||||
|
port = port,
|
||||||
|
domain = domain,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
|
||||||
|
format!(
|
||||||
|
r#"[Unit]
|
||||||
|
Description=Telemt MTProxy
|
||||||
|
Documentation=https://github.com/nicepkg/telemt
|
||||||
|
After=network-online.target
|
||||||
|
Wants=network-online.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
ExecStart={exe} {config}
|
||||||
|
Restart=always
|
||||||
|
RestartSec=5
|
||||||
|
LimitNOFILE=65535
|
||||||
|
# Security hardening
|
||||||
|
NoNewPrivileges=true
|
||||||
|
ProtectSystem=strict
|
||||||
|
ProtectHome=true
|
||||||
|
ReadWritePaths=/etc/telemt
|
||||||
|
PrivateTmp=true
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
"#,
|
||||||
|
exe = exe_path.display(),
|
||||||
|
config = config_path.display(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_cmd(cmd: &str, args: &[&str]) {
|
||||||
|
match Command::new(cmd).args(args).output() {
|
||||||
|
Ok(output) => {
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
|
||||||
|
let domain_hex = hex::encode(domain);
|
||||||
|
|
||||||
|
println!("=== Proxy Links ===");
|
||||||
|
println!("[{}]", username);
|
||||||
|
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
|
||||||
|
port, secret, domain_hex);
|
||||||
|
println!();
|
||||||
|
println!("Replace YOUR_SERVER_IP with your server's public IP.");
|
||||||
|
println!("The proxy will auto-detect and display the correct link on startup.");
|
||||||
|
println!("Check: journalctl -u telemt.service | head -30");
|
||||||
|
println!("===================");
|
||||||
|
}
|
||||||
@@ -1,31 +1,108 @@
|
|||||||
//! Configuration
|
//! Configuration
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use crate::error::{ProxyError, Result};
|
||||||
use std::net::IpAddr;
|
|
||||||
use std::path::Path;
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::error::{ProxyError, Result};
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
// ============= Helper Defaults =============
|
// ============= Helper Defaults =============
|
||||||
|
|
||||||
fn default_true() -> bool { true }
|
fn default_true() -> bool {
|
||||||
fn default_port() -> u16 { 443 }
|
true
|
||||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
}
|
||||||
fn default_mask_port() -> u16 { 443 }
|
fn default_port() -> u16 {
|
||||||
fn default_replay_check_len() -> usize { 65536 }
|
443
|
||||||
fn default_handshake_timeout() -> u64 { 15 }
|
}
|
||||||
fn default_connect_timeout() -> u64 { 10 }
|
fn default_tls_domain() -> String {
|
||||||
fn default_keepalive() -> u64 { 60 }
|
"www.google.com".to_string()
|
||||||
fn default_ack_timeout() -> u64 { 300 }
|
}
|
||||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
fn default_mask_port() -> u16 {
|
||||||
fn default_fake_cert_len() -> usize { 2048 }
|
443
|
||||||
fn default_weight() -> u16 { 1 }
|
}
|
||||||
|
fn default_replay_check_len() -> usize {
|
||||||
|
65536
|
||||||
|
}
|
||||||
|
fn default_replay_window_secs() -> u64 {
|
||||||
|
1800
|
||||||
|
}
|
||||||
|
fn default_handshake_timeout() -> u64 {
|
||||||
|
15
|
||||||
|
}
|
||||||
|
fn default_connect_timeout() -> u64 {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
fn default_keepalive() -> u64 {
|
||||||
|
60
|
||||||
|
}
|
||||||
|
fn default_ack_timeout() -> u64 {
|
||||||
|
300
|
||||||
|
}
|
||||||
|
fn default_listen_addr() -> String {
|
||||||
|
"0.0.0.0".to_string()
|
||||||
|
}
|
||||||
|
fn default_fake_cert_len() -> usize {
|
||||||
|
2048
|
||||||
|
}
|
||||||
|
fn default_weight() -> u16 {
|
||||||
|
1
|
||||||
|
}
|
||||||
fn default_metrics_whitelist() -> Vec<IpAddr> {
|
fn default_metrics_whitelist() -> Vec<IpAddr> {
|
||||||
vec![
|
vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()]
|
||||||
"127.0.0.1".parse().unwrap(),
|
}
|
||||||
"::1".parse().unwrap(),
|
|
||||||
]
|
// ============= Log Level =============
|
||||||
|
|
||||||
|
/// Logging verbosity level
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum LogLevel {
|
||||||
|
/// All messages including trace (trace + debug + info + warn + error)
|
||||||
|
Debug,
|
||||||
|
/// Detailed operational logs (debug + info + warn + error)
|
||||||
|
Verbose,
|
||||||
|
/// Standard operational logs (info + warn + error)
|
||||||
|
#[default]
|
||||||
|
Normal,
|
||||||
|
/// Minimal output: only warnings and errors (warn + error).
|
||||||
|
/// Startup messages (config, DC connectivity, proxy links) are always shown
|
||||||
|
/// via info! before the filter is applied.
|
||||||
|
Silent,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogLevel {
|
||||||
|
/// Convert to tracing EnvFilter directive string
|
||||||
|
pub fn to_filter_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
LogLevel::Debug => "trace",
|
||||||
|
LogLevel::Verbose => "debug",
|
||||||
|
LogLevel::Normal => "info",
|
||||||
|
LogLevel::Silent => "warn",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse from a loose string (CLI argument)
|
||||||
|
pub fn from_str_loose(s: &str) -> Self {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"debug" | "trace" => LogLevel::Debug,
|
||||||
|
"verbose" => LogLevel::Verbose,
|
||||||
|
"normal" | "info" => LogLevel::Normal,
|
||||||
|
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
|
||||||
|
_ => LogLevel::Normal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for LogLevel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
LogLevel::Debug => write!(f, "debug"),
|
||||||
|
LogLevel::Verbose => write!(f, "verbose"),
|
||||||
|
LogLevel::Normal => write!(f, "normal"),
|
||||||
|
LogLevel::Silent => write!(f, "silent"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============= Sub-Configs =============
|
// ============= Sub-Configs =============
|
||||||
@@ -42,7 +119,11 @@ pub struct ProxyModes {
|
|||||||
|
|
||||||
impl Default for ProxyModes {
|
impl Default for ProxyModes {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self { classic: true, secure: true, tls: true }
|
Self {
|
||||||
|
classic: true,
|
||||||
|
secure: true,
|
||||||
|
tls: true,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +143,27 @@ pub struct GeneralConfig {
|
|||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub ad_tag: Option<String>,
|
pub ad_tag: Option<String>,
|
||||||
|
|
||||||
|
/// Path to proxy-secret binary file (auto-downloaded if absent).
|
||||||
|
/// Infrastructure secret from https://core.telegram.org/getProxySecret
|
||||||
|
#[serde(default)]
|
||||||
|
pub proxy_secret_path: Option<String>,
|
||||||
|
|
||||||
|
/// Public IP override for middle-proxy NAT environments.
|
||||||
|
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
|
||||||
|
#[serde(default)]
|
||||||
|
pub middle_proxy_nat_ip: Option<IpAddr>,
|
||||||
|
|
||||||
|
/// Enable STUN-based NAT probing to discover public IP:port for ME KDF.
|
||||||
|
#[serde(default)]
|
||||||
|
pub middle_proxy_nat_probe: bool,
|
||||||
|
|
||||||
|
/// Optional STUN server address (host:port) for NAT probing.
|
||||||
|
#[serde(default)]
|
||||||
|
pub middle_proxy_nat_stun: Option<String>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub log_level: LogLevel,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for GeneralConfig {
|
impl Default for GeneralConfig {
|
||||||
@@ -72,6 +174,11 @@ impl Default for GeneralConfig {
|
|||||||
fast_mode: true,
|
fast_mode: true,
|
||||||
use_middle_proxy: false,
|
use_middle_proxy: false,
|
||||||
ad_tag: None,
|
ad_tag: None,
|
||||||
|
proxy_secret_path: None,
|
||||||
|
middle_proxy_nat_ip: None,
|
||||||
|
middle_proxy_nat_probe: false,
|
||||||
|
middle_proxy_nat_stun: None,
|
||||||
|
log_level: LogLevel::Normal,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,6 +261,9 @@ pub struct AntiCensorshipConfig {
|
|||||||
#[serde(default = "default_mask_port")]
|
#[serde(default = "default_mask_port")]
|
||||||
pub mask_port: u16,
|
pub mask_port: u16,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub mask_unix_sock: Option<String>,
|
||||||
|
|
||||||
#[serde(default = "default_fake_cert_len")]
|
#[serde(default = "default_fake_cert_len")]
|
||||||
pub fake_cert_len: usize,
|
pub fake_cert_len: usize,
|
||||||
}
|
}
|
||||||
@@ -165,6 +275,7 @@ impl Default for AntiCensorshipConfig {
|
|||||||
mask: true,
|
mask: true,
|
||||||
mask_host: None,
|
mask_host: None,
|
||||||
mask_port: default_mask_port(),
|
mask_port: default_mask_port(),
|
||||||
|
mask_unix_sock: None,
|
||||||
fake_cert_len: default_fake_cert_len(),
|
fake_cert_len: default_fake_cert_len(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -187,6 +298,9 @@ pub struct AccessConfig {
|
|||||||
#[serde(default = "default_replay_check_len")]
|
#[serde(default = "default_replay_check_len")]
|
||||||
pub replay_check_len: usize,
|
pub replay_check_len: usize,
|
||||||
|
|
||||||
|
#[serde(default = "default_replay_window_secs")]
|
||||||
|
pub replay_window_secs: u64,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub ignore_time_skew: bool,
|
pub ignore_time_skew: bool,
|
||||||
}
|
}
|
||||||
@@ -194,13 +308,17 @@ pub struct AccessConfig {
|
|||||||
impl Default for AccessConfig {
|
impl Default for AccessConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
let mut users = HashMap::new();
|
let mut users = HashMap::new();
|
||||||
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
|
users.insert(
|
||||||
|
"default".to_string(),
|
||||||
|
"00000000000000000000000000000000".to_string(),
|
||||||
|
);
|
||||||
Self {
|
Self {
|
||||||
users,
|
users,
|
||||||
user_max_tcp_conns: HashMap::new(),
|
user_max_tcp_conns: HashMap::new(),
|
||||||
user_expirations: HashMap::new(),
|
user_expirations: HashMap::new(),
|
||||||
user_data_quota: HashMap::new(),
|
user_data_quota: HashMap::new(),
|
||||||
replay_check_len: default_replay_check_len(),
|
replay_check_len: default_replay_check_len(),
|
||||||
|
replay_window_secs: default_replay_window_secs(),
|
||||||
ignore_time_skew: false,
|
ignore_time_skew: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -274,15 +392,30 @@ pub struct ProxyConfig {
|
|||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub show_link: Vec<String>,
|
pub show_link: Vec<String>,
|
||||||
|
|
||||||
|
/// DC address overrides for non-standard DCs (CDN, media, test, etc.)
|
||||||
|
/// Keys are DC indices as strings, values are "ip:port" addresses.
|
||||||
|
/// Matches the C implementation's `proxy_for <dc_id> <ip>:<port>` config directive.
|
||||||
|
/// Example in config.toml:
|
||||||
|
/// [dc_overrides]
|
||||||
|
/// "203" = "149.154.175.100:443"
|
||||||
|
#[serde(default)]
|
||||||
|
pub dc_overrides: HashMap<String, String>,
|
||||||
|
|
||||||
|
/// Default DC index (1-5) for unmapped non-standard DCs.
|
||||||
|
/// Matches the C implementation's `default <dc_id>` config directive.
|
||||||
|
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_dc: Option<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProxyConfig {
|
impl ProxyConfig {
|
||||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||||
let content = std::fs::read_to_string(path)
|
let content =
|
||||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
|
|
||||||
let mut config: ProxyConfig = toml::from_str(&content)
|
let mut config: ProxyConfig =
|
||||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
|
|
||||||
// Validate secrets
|
// Validate secrets
|
||||||
for (user, secret) in &config.access.users {
|
for (user, secret) in &config.access.users {
|
||||||
@@ -299,20 +432,40 @@ impl ProxyConfig {
|
|||||||
return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
|
return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn if using default tls_domain
|
// Validate mask_unix_sock
|
||||||
if config.censorship.tls_domain == "www.google.com" {
|
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
||||||
tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml");
|
if sock_path.is_empty() {
|
||||||
|
return Err(ProxyError::Config(
|
||||||
|
"mask_unix_sock cannot be empty".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
#[cfg(unix)]
|
||||||
|
if sock_path.len() > 107 {
|
||||||
|
return Err(ProxyError::Config(format!(
|
||||||
|
"mask_unix_sock path too long: {} bytes (max 107)",
|
||||||
|
sock_path.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
return Err(ProxyError::Config(
|
||||||
|
"mask_unix_sock is only supported on Unix platforms".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
if config.censorship.mask_host.is_some() {
|
||||||
|
return Err(ProxyError::Config(
|
||||||
|
"mask_unix_sock and mask_host are mutually exclusive".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default mask_host to tls_domain if not set
|
// Default mask_host to tls_domain if not set and no unix socket configured
|
||||||
if config.censorship.mask_host.is_none() {
|
if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() {
|
||||||
tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain);
|
|
||||||
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
|
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Random fake_cert_len
|
// Random fake_cert_len
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
|
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
|
||||||
|
|
||||||
// Migration: Populate listeners if empty
|
// Migration: Populate listeners if empty
|
||||||
if config.server.listeners.is_empty() {
|
if config.server.listeners.is_empty() {
|
||||||
@@ -353,11 +506,11 @@ impl ProxyConfig {
|
|||||||
return Err(ProxyError::Config("No modes enabled".to_string()));
|
return Err(ProxyError::Config("No modes enabled".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate tls_domain format (basic check)
|
|
||||||
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
|
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
|
||||||
return Err(ProxyError::Config(
|
return Err(ProxyError::Config(format!(
|
||||||
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
|
"Invalid tls_domain: '{}'. Must be a valid domain name",
|
||||||
));
|
self.censorship.tls_domain
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
//! AES encryption implementations
|
//! AES encryption implementations
|
||||||
//!
|
//!
|
||||||
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
|
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
|
||||||
|
//!
|
||||||
|
//! ## Zeroize policy
|
||||||
|
//!
|
||||||
|
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
|
||||||
|
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
|
||||||
|
//! The expanded key schedule lives inside that type and cannot be
|
||||||
|
//! zeroized from outside. Callers that hold raw key material (e.g.
|
||||||
|
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
|
||||||
|
//! zeroizing their own copies.
|
||||||
|
|
||||||
use aes::Aes256;
|
use aes::Aes256;
|
||||||
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
||||||
|
use zeroize::Zeroize;
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
type Aes256Ctr = Ctr128BE<Aes256>;
|
type Aes256Ctr = Ctr128BE<Aes256>;
|
||||||
@@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE<Aes256>;
|
|||||||
|
|
||||||
/// AES-256-CTR encryptor/decryptor
|
/// AES-256-CTR encryptor/decryptor
|
||||||
///
|
///
|
||||||
/// CTR mode is symmetric - encryption and decryption are the same operation.
|
/// CTR mode is symmetric — encryption and decryption are the same operation.
|
||||||
|
///
|
||||||
|
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
|
||||||
|
/// + counter) is opaque and cannot be zeroized. If you need to protect key
|
||||||
|
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
|
||||||
|
/// before dropping them.
|
||||||
pub struct AesCtr {
|
pub struct AesCtr {
|
||||||
cipher: Aes256Ctr,
|
cipher: Aes256Ctr,
|
||||||
}
|
}
|
||||||
@@ -62,14 +77,23 @@ impl AesCtr {
|
|||||||
|
|
||||||
/// AES-256-CBC cipher with proper chaining
|
/// AES-256-CBC cipher with proper chaining
|
||||||
///
|
///
|
||||||
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
|
/// Unlike CTR mode, CBC is NOT symmetric — encryption and decryption
|
||||||
/// are different operations. This implementation handles CBC chaining
|
/// are different operations. This implementation handles CBC chaining
|
||||||
/// correctly across multiple blocks.
|
/// correctly across multiple blocks.
|
||||||
|
///
|
||||||
|
/// Key and IV are zeroized on drop.
|
||||||
pub struct AesCbc {
|
pub struct AesCbc {
|
||||||
key: [u8; 32],
|
key: [u8; 32],
|
||||||
iv: [u8; 16],
|
iv: [u8; 16],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for AesCbc {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.key.zeroize();
|
||||||
|
self.iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AesCbc {
|
impl AesCbc {
|
||||||
/// AES block size
|
/// AES block size
|
||||||
const BLOCK_SIZE: usize = 16;
|
const BLOCK_SIZE: usize = 16;
|
||||||
@@ -141,17 +165,9 @@ impl AesCbc {
|
|||||||
|
|
||||||
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
let plaintext: [u8; 16] = chunk.try_into().unwrap();
|
let plaintext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
|
||||||
// XOR plaintext with previous ciphertext (or IV for first block)
|
|
||||||
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
|
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
|
||||||
|
|
||||||
// Encrypt the XORed block
|
|
||||||
let ciphertext = self.encrypt_block(&xored, &key_schedule);
|
let ciphertext = self.encrypt_block(&xored, &key_schedule);
|
||||||
|
|
||||||
// Save for next iteration
|
|
||||||
prev_ciphertext = ciphertext;
|
prev_ciphertext = ciphertext;
|
||||||
|
|
||||||
// Append to result
|
|
||||||
result.extend_from_slice(&ciphertext);
|
result.extend_from_slice(&ciphertext);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,17 +196,9 @@ impl AesCbc {
|
|||||||
|
|
||||||
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
|
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
|
||||||
// Decrypt the block
|
|
||||||
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
|
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
|
||||||
|
|
||||||
// XOR with previous ciphertext (or IV for first block)
|
|
||||||
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
|
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
|
||||||
|
|
||||||
// Save current ciphertext for next iteration
|
|
||||||
prev_ciphertext = ciphertext;
|
prev_ciphertext = ciphertext;
|
||||||
|
|
||||||
// Append to result
|
|
||||||
result.extend_from_slice(&plaintext);
|
result.extend_from_slice(&plaintext);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,16 +225,13 @@ impl AesCbc {
|
|||||||
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
// XOR with previous ciphertext
|
|
||||||
for j in 0..Self::BLOCK_SIZE {
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
block[j] ^= prev_ciphertext[j];
|
block[j] ^= prev_ciphertext[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt in-place
|
|
||||||
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
*block_array = self.encrypt_block(block_array, &key_schedule);
|
*block_array = self.encrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
// Save for next iteration
|
|
||||||
prev_ciphertext = *block_array;
|
prev_ciphertext = *block_array;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,26 +253,20 @@ impl AesCbc {
|
|||||||
use aes::cipher::KeyInit;
|
use aes::cipher::KeyInit;
|
||||||
let key_schedule = aes::Aes256::new((&self.key).into());
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
// For in-place decryption, we need to save ciphertext blocks
|
|
||||||
// before we overwrite them
|
|
||||||
let mut prev_ciphertext = self.iv;
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
// Save current ciphertext before modifying
|
|
||||||
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
|
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
|
||||||
|
|
||||||
// Decrypt in-place
|
|
||||||
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
*block_array = self.decrypt_block(block_array, &key_schedule);
|
*block_array = self.decrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
// XOR with previous ciphertext
|
|
||||||
for j in 0..Self::BLOCK_SIZE {
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
block[j] ^= prev_ciphertext[j];
|
block[j] ^= prev_ciphertext[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save for next iteration
|
|
||||||
prev_ciphertext = current_ciphertext;
|
prev_ciphertext = current_ciphertext;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -347,10 +346,8 @@ mod tests {
|
|||||||
let mut cipher = AesCtr::new(&key, iv);
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
cipher.apply(&mut data);
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
// Encrypted should be different
|
|
||||||
assert_ne!(&data[..], original);
|
assert_ne!(&data[..], original);
|
||||||
|
|
||||||
// Decrypt with fresh cipher
|
|
||||||
let mut cipher = AesCtr::new(&key, iv);
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
cipher.apply(&mut data);
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
@@ -364,7 +361,7 @@ mod tests {
|
|||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
|
|
||||||
let original = [0u8; 32]; // 2 blocks
|
let original = [0u8; 32];
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let encrypted = cipher.encrypt(&original).unwrap();
|
let encrypted = cipher.encrypt(&original).unwrap();
|
||||||
@@ -375,31 +372,25 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_chaining_works() {
|
fn test_aes_cbc_chaining_works() {
|
||||||
// This is the key test - verify CBC chaining is correct
|
|
||||||
let key = [0x42u8; 32];
|
let key = [0x42u8; 32];
|
||||||
let iv = [0x00u8; 16];
|
let iv = [0x00u8; 16];
|
||||||
|
|
||||||
// Two IDENTICAL plaintext blocks
|
|
||||||
let plaintext = [0xAAu8; 32];
|
let plaintext = [0xAAu8; 32];
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
|
|
||||||
let block1 = &ciphertext[0..16];
|
let block1 = &ciphertext[0..16];
|
||||||
let block2 = &ciphertext[16..32];
|
let block2 = &ciphertext[16..32];
|
||||||
|
|
||||||
assert_ne!(
|
assert_ne!(
|
||||||
block1, block2,
|
block1, block2,
|
||||||
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \
|
"CBC chaining broken: identical plaintext blocks produced identical ciphertext"
|
||||||
This indicates ECB mode, not CBC!"
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_known_vector() {
|
fn test_aes_cbc_known_vector() {
|
||||||
// Test with known NIST test vector
|
|
||||||
// AES-256-CBC with zero key and zero IV
|
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
let plaintext = [0u8; 16];
|
let plaintext = [0u8; 16];
|
||||||
@@ -407,11 +398,9 @@ mod tests {
|
|||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// Decrypt and verify roundtrip
|
|
||||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||||
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
||||||
|
|
||||||
// Ciphertext should not be all zeros
|
|
||||||
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
|
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,7 +409,6 @@ mod tests {
|
|||||||
let key = [0x12u8; 32];
|
let key = [0x12u8; 32];
|
||||||
let iv = [0x34u8; 16];
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
// 5 blocks = 80 bytes
|
|
||||||
let plaintext: Vec<u8> = (0..80).collect();
|
let plaintext: Vec<u8> = (0..80).collect();
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
@@ -435,7 +423,7 @@ mod tests {
|
|||||||
let key = [0x12u8; 32];
|
let key = [0x12u8; 32];
|
||||||
let iv = [0x34u8; 16];
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
let original = [0x56u8; 48]; // 3 blocks
|
let original = [0x56u8; 48];
|
||||||
let mut buffer = original;
|
let mut buffer = original;
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
@@ -462,41 +450,33 @@ mod tests {
|
|||||||
fn test_aes_cbc_unaligned_error() {
|
fn test_aes_cbc_unaligned_error() {
|
||||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||||
|
|
||||||
// 15 bytes - not aligned to block size
|
|
||||||
let result = cipher.encrypt(&[0u8; 15]);
|
let result = cipher.encrypt(&[0u8; 15]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
// 17 bytes - not aligned
|
|
||||||
let result = cipher.encrypt(&[0u8; 17]);
|
let result = cipher.encrypt(&[0u8; 17]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_avalanche_effect() {
|
fn test_aes_cbc_avalanche_effect() {
|
||||||
// Changing one bit in plaintext should change entire ciphertext block
|
|
||||||
// and all subsequent blocks (due to chaining)
|
|
||||||
let key = [0xAB; 32];
|
let key = [0xAB; 32];
|
||||||
let iv = [0xCD; 16];
|
let iv = [0xCD; 16];
|
||||||
|
|
||||||
let mut plaintext1 = [0u8; 32];
|
let plaintext1 = [0u8; 32];
|
||||||
let mut plaintext2 = [0u8; 32];
|
let mut plaintext2 = [0u8; 32];
|
||||||
plaintext2[0] = 0x01; // Single bit difference in first block
|
plaintext2[0] = 0x01;
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
||||||
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
||||||
|
|
||||||
// First blocks should be different
|
|
||||||
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
||||||
|
|
||||||
// Second blocks should ALSO be different (chaining effect)
|
|
||||||
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_iv_matters() {
|
fn test_aes_cbc_iv_matters() {
|
||||||
// Same plaintext with different IVs should produce different ciphertext
|
|
||||||
let key = [0x55; 32];
|
let key = [0x55; 32];
|
||||||
let plaintext = [0x77u8; 16];
|
let plaintext = [0x77u8; 16];
|
||||||
|
|
||||||
@@ -511,7 +491,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_deterministic() {
|
fn test_aes_cbc_deterministic() {
|
||||||
// Same key, IV, plaintext should always produce same ciphertext
|
|
||||||
let key = [0x99; 32];
|
let key = [0x99; 32];
|
||||||
let iv = [0x88; 16];
|
let iv = [0x88; 16];
|
||||||
let plaintext = [0x77u8; 32];
|
let plaintext = [0x77u8; 32];
|
||||||
@@ -524,6 +503,23 @@ mod tests {
|
|||||||
assert_eq!(ciphertext1, ciphertext2);
|
assert_eq!(ciphertext1, ciphertext2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= Zeroize Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_zeroize_on_drop() {
|
||||||
|
let key = [0xAA; 32];
|
||||||
|
let iv = [0xBB; 16];
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
// Verify key/iv are set
|
||||||
|
assert_eq!(cipher.key, [0xAA; 32]);
|
||||||
|
assert_eq!(cipher.iv, [0xBB; 16]);
|
||||||
|
|
||||||
|
drop(cipher);
|
||||||
|
// After drop, key/iv are zeroized (can't observe directly,
|
||||||
|
// but the Drop impl runs without panic)
|
||||||
|
}
|
||||||
|
|
||||||
// ============= Error Handling Tests =============
|
// ============= Error Handling Tests =============
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
//! Cryptographic hash functions
|
||||||
|
//!
|
||||||
|
//! ## Protocol-required algorithms
|
||||||
|
//!
|
||||||
|
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
|
||||||
|
//! hash functions are **required by the Telegram Middle Proxy protocol**
|
||||||
|
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
|
||||||
|
//! compatibility. They are NOT used for any security-sensitive purpose
|
||||||
|
//! outside of that specific key derivation scheme mandated by Telegram.
|
||||||
|
//!
|
||||||
|
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
|
||||||
|
//! usages are intentional and protocol-mandated.
|
||||||
|
|
||||||
use hmac::{Hmac, Mac};
|
use hmac::{Hmac, Mac};
|
||||||
use sha2::Sha256;
|
use sha2::Sha256;
|
||||||
use md5::Md5;
|
use md5::Md5;
|
||||||
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
|
|||||||
mac.finalize().into_bytes().into()
|
mac.finalize().into_bytes().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SHA-1
|
/// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
|
||||||
|
/// Not used for general-purpose hashing.
|
||||||
pub fn sha1(data: &[u8]) -> [u8; 20] {
|
pub fn sha1(data: &[u8]) -> [u8; 20] {
|
||||||
let mut hasher = Sha1::new();
|
let mut hasher = Sha1::new();
|
||||||
hasher.update(data);
|
hasher.update(data);
|
||||||
hasher.finalize().into()
|
hasher.finalize().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MD5
|
/// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
|
||||||
|
/// Not used for general-purpose hashing.
|
||||||
pub fn md5(data: &[u8]) -> [u8; 16] {
|
pub fn md5(data: &[u8]) -> [u8; 16] {
|
||||||
let mut hasher = Md5::new();
|
let mut hasher = Md5::new();
|
||||||
hasher.update(data);
|
hasher.update(data);
|
||||||
@@ -40,8 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
|
|||||||
crc32fast::hash(data)
|
crc32fast::hash(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Middle Proxy Keygen
|
/// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
|
||||||
pub fn derive_middleproxy_keys(
|
///
|
||||||
|
/// Returned buffer layout (IPv4):
|
||||||
|
/// nonce_srv | nonce_clt | clt_ts | srv_ip | clt_port | purpose | clt_ip | srv_port | secret | nonce_srv | [clt_v6 | srv_v6] | nonce_clt
|
||||||
|
pub fn build_middleproxy_prekey(
|
||||||
nonce_srv: &[u8; 16],
|
nonce_srv: &[u8; 16],
|
||||||
nonce_clt: &[u8; 16],
|
nonce_clt: &[u8; 16],
|
||||||
clt_ts: &[u8; 4],
|
clt_ts: &[u8; 4],
|
||||||
@@ -53,7 +71,7 @@ pub fn derive_middleproxy_keys(
|
|||||||
secret: &[u8],
|
secret: &[u8],
|
||||||
clt_ipv6: Option<&[u8; 16]>,
|
clt_ipv6: Option<&[u8; 16]>,
|
||||||
srv_ipv6: Option<&[u8; 16]>,
|
srv_ipv6: Option<&[u8; 16]>,
|
||||||
) -> ([u8; 32], [u8; 16]) {
|
) -> Vec<u8> {
|
||||||
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
|
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
|
||||||
|
|
||||||
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
|
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
|
||||||
@@ -77,6 +95,40 @@ pub fn derive_middleproxy_keys(
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.extend_from_slice(nonce_clt);
|
s.extend_from_slice(nonce_clt);
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middle Proxy key derivation
|
||||||
|
///
|
||||||
|
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
|
||||||
|
/// These algorithms are NOT replaceable here — changing them would break
|
||||||
|
/// interoperability with Telegram's middle proxy infrastructure.
|
||||||
|
pub fn derive_middleproxy_keys(
|
||||||
|
nonce_srv: &[u8; 16],
|
||||||
|
nonce_clt: &[u8; 16],
|
||||||
|
clt_ts: &[u8; 4],
|
||||||
|
srv_ip: Option<&[u8]>,
|
||||||
|
clt_port: &[u8; 2],
|
||||||
|
purpose: &[u8],
|
||||||
|
clt_ip: Option<&[u8]>,
|
||||||
|
srv_port: &[u8; 2],
|
||||||
|
secret: &[u8],
|
||||||
|
clt_ipv6: Option<&[u8; 16]>,
|
||||||
|
srv_ipv6: Option<&[u8; 16]>,
|
||||||
|
) -> ([u8; 32], [u8; 16]) {
|
||||||
|
let s = build_middleproxy_prekey(
|
||||||
|
nonce_srv,
|
||||||
|
nonce_clt,
|
||||||
|
clt_ts,
|
||||||
|
srv_ip,
|
||||||
|
clt_port,
|
||||||
|
purpose,
|
||||||
|
clt_ip,
|
||||||
|
srv_port,
|
||||||
|
secret,
|
||||||
|
clt_ipv6,
|
||||||
|
srv_ipv6,
|
||||||
|
);
|
||||||
|
|
||||||
let md5_1 = md5(&s[1..]);
|
let md5_1 = md5(&s[1..]);
|
||||||
let sha1_sum = sha1(&s);
|
let sha1_sum = sha1(&s);
|
||||||
@@ -88,3 +140,39 @@ pub fn derive_middleproxy_keys(
|
|||||||
|
|
||||||
(key, md5_2)
|
(key, md5_2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn middleproxy_prekey_sha_is_stable() {
|
||||||
|
let nonce_srv = [0x11u8; 16];
|
||||||
|
let nonce_clt = [0x22u8; 16];
|
||||||
|
let clt_ts = 0x44332211u32.to_le_bytes();
|
||||||
|
let srv_ip = Some([149u8, 154, 175, 50].as_ref());
|
||||||
|
let clt_ip = Some([10u8, 0, 0, 1].as_ref());
|
||||||
|
let clt_port = 0x1f90u16.to_le_bytes(); // 8080
|
||||||
|
let srv_port = 0x22b8u16.to_le_bytes(); // 8888
|
||||||
|
let secret = vec![0x55u8; 128];
|
||||||
|
|
||||||
|
let prekey = build_middleproxy_prekey(
|
||||||
|
&nonce_srv,
|
||||||
|
&nonce_clt,
|
||||||
|
&clt_ts,
|
||||||
|
srv_ip,
|
||||||
|
&clt_port,
|
||||||
|
b"CLIENT",
|
||||||
|
clt_ip,
|
||||||
|
&srv_port,
|
||||||
|
&secret,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let digest = sha256(&prekey);
|
||||||
|
assert_eq!(
|
||||||
|
hex::encode(digest),
|
||||||
|
"a4595b75f1f610f2575ace802ddc65c91b5acef3b0e0d18189e0c7c9f787d15c"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,5 +5,5 @@ pub mod hash;
|
|||||||
pub mod random;
|
pub mod random;
|
||||||
|
|
||||||
pub use aes::{AesCtr, AesCbc};
|
pub use aes::{AesCtr, AesCbc};
|
||||||
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
|
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey};
|
||||||
pub use random::{SecureRandom, SECURE_RANDOM};
|
pub use random::SecureRandom;
|
||||||
|
|||||||
@@ -3,11 +3,8 @@
|
|||||||
use rand::{Rng, RngCore, SeedableRng};
|
use rand::{Rng, RngCore, SeedableRng};
|
||||||
use rand::rngs::StdRng;
|
use rand::rngs::StdRng;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use zeroize::Zeroize;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
|
|
||||||
/// Global secure random instance
|
|
||||||
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
|
|
||||||
|
|
||||||
/// Cryptographically secure PRNG with AES-CTR
|
/// Cryptographically secure PRNG with AES-CTR
|
||||||
pub struct SecureRandom {
|
pub struct SecureRandom {
|
||||||
@@ -20,18 +17,30 @@ struct SecureRandomInner {
|
|||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for SecureRandomInner {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.buffer.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SecureRandom {
|
impl SecureRandom {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let mut rng = StdRng::from_entropy();
|
let mut seed_source = rand::rng();
|
||||||
|
let mut rng = StdRng::from_rng(&mut seed_source);
|
||||||
|
|
||||||
let mut key = [0u8; 32];
|
let mut key = [0u8; 32];
|
||||||
rng.fill_bytes(&mut key);
|
rng.fill_bytes(&mut key);
|
||||||
let iv: u128 = rng.gen();
|
let iv: u128 = rng.random();
|
||||||
|
|
||||||
|
let cipher = AesCtr::new(&key, iv);
|
||||||
|
|
||||||
|
// Zeroize local key copy — cipher already consumed it
|
||||||
|
key.zeroize();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
inner: Mutex::new(SecureRandomInner {
|
inner: Mutex::new(SecureRandomInner {
|
||||||
rng,
|
rng,
|
||||||
cipher: AesCtr::new(&key, iv),
|
cipher,
|
||||||
buffer: Vec::with_capacity(1024),
|
buffer: Vec::with_capacity(1024),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@@ -78,7 +87,6 @@ impl SecureRandom {
|
|||||||
result |= (b as u64) << (i * 8);
|
result |= (b as u64) << (i * 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mask extra bits
|
|
||||||
if k < 64 {
|
if k < 64 {
|
||||||
result &= (1u64 << k) - 1;
|
result &= (1u64 << k) - 1;
|
||||||
}
|
}
|
||||||
@@ -107,13 +115,13 @@ impl SecureRandom {
|
|||||||
/// Generate random u32
|
/// Generate random u32
|
||||||
pub fn u32(&self) -> u32 {
|
pub fn u32(&self) -> u32 {
|
||||||
let mut inner = self.inner.lock();
|
let mut inner = self.inner.lock();
|
||||||
inner.rng.gen()
|
inner.rng.random()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate random u64
|
/// Generate random u64
|
||||||
pub fn u64(&self) -> u64 {
|
pub fn u64(&self) -> u64 {
|
||||||
let mut inner = self.inner.lock();
|
let mut inner = self.inner.lock();
|
||||||
inner.rng.gen()
|
inner.rng.random()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,12 +170,10 @@ mod tests {
|
|||||||
fn test_bits() {
|
fn test_bits() {
|
||||||
let rng = SecureRandom::new();
|
let rng = SecureRandom::new();
|
||||||
|
|
||||||
// Single bit should be 0 or 1
|
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
assert!(rng.bits(1) <= 1);
|
assert!(rng.bits(1) <= 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8 bits should be 0-255
|
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
assert!(rng.bits(8) <= 255);
|
assert!(rng.bits(8) <= 255);
|
||||||
}
|
}
|
||||||
@@ -185,10 +191,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have seen all items
|
|
||||||
assert_eq!(seen.len(), 5);
|
assert_eq!(seen.len(), 5);
|
||||||
|
|
||||||
// Empty slice should return None
|
|
||||||
let empty: Vec<i32> = vec![];
|
let empty: Vec<i32> = vec![];
|
||||||
assert!(rng.choose(&empty).is_none());
|
assert!(rng.choose(&empty).is_none());
|
||||||
}
|
}
|
||||||
@@ -201,12 +205,10 @@ mod tests {
|
|||||||
let mut shuffled = original.clone();
|
let mut shuffled = original.clone();
|
||||||
rng.shuffle(&mut shuffled);
|
rng.shuffle(&mut shuffled);
|
||||||
|
|
||||||
// Should contain same elements
|
|
||||||
let mut sorted = shuffled.clone();
|
let mut sorted = shuffled.clone();
|
||||||
sorted.sort();
|
sorted.sort();
|
||||||
assert_eq!(sorted, original);
|
assert_eq!(sorted, original);
|
||||||
|
|
||||||
// Should be different order (with very high probability)
|
|
||||||
assert_ne!(shuffled, original);
|
assert_ne!(shuffled, original);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
13
src/error.rs
13
src/error.rs
@@ -118,16 +118,13 @@ pub trait Recoverable {
|
|||||||
impl Recoverable for StreamError {
|
impl Recoverable for StreamError {
|
||||||
fn is_recoverable(&self) -> bool {
|
fn is_recoverable(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
// Partial operations can be retried
|
|
||||||
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
||||||
// I/O errors depend on kind
|
|
||||||
Self::Io(e) => matches!(
|
Self::Io(e) => matches!(
|
||||||
e.kind(),
|
e.kind(),
|
||||||
std::io::ErrorKind::WouldBlock
|
std::io::ErrorKind::WouldBlock
|
||||||
| std::io::ErrorKind::Interrupted
|
| std::io::ErrorKind::Interrupted
|
||||||
| std::io::ErrorKind::TimedOut
|
| std::io::ErrorKind::TimedOut
|
||||||
),
|
),
|
||||||
// These are not recoverable
|
|
||||||
Self::Poisoned { .. }
|
Self::Poisoned { .. }
|
||||||
| Self::BufferOverflow { .. }
|
| Self::BufferOverflow { .. }
|
||||||
| Self::InvalidFrame { .. }
|
| Self::InvalidFrame { .. }
|
||||||
@@ -137,13 +134,9 @@ impl Recoverable for StreamError {
|
|||||||
|
|
||||||
fn can_continue(&self) -> bool {
|
fn can_continue(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
// Poisoned stream cannot be used
|
|
||||||
Self::Poisoned { .. } => false,
|
Self::Poisoned { .. } => false,
|
||||||
// EOF means stream is done
|
|
||||||
Self::UnexpectedEof => false,
|
Self::UnexpectedEof => false,
|
||||||
// Buffer overflow is fatal
|
|
||||||
Self::BufferOverflow { .. } => false,
|
Self::BufferOverflow { .. } => false,
|
||||||
// Others might allow continuation
|
|
||||||
_ => true,
|
_ => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -383,18 +376,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_handshake_result() {
|
fn test_handshake_result() {
|
||||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
||||||
assert!(success.is_success());
|
assert!(success.is_success());
|
||||||
assert!(!success.is_bad_client());
|
assert!(!success.is_bad_client());
|
||||||
|
|
||||||
let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
|
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
|
||||||
assert!(!bad.is_success());
|
assert!(!bad.is_success());
|
||||||
assert!(bad.is_bad_client());
|
assert!(bad.is_bad_client());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_handshake_result_map() {
|
fn test_handshake_result_map() {
|
||||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
||||||
let mapped = success.map(|x| x * 2);
|
let mapped = success.map(|x| x * 2);
|
||||||
|
|
||||||
match mapped {
|
match mapped {
|
||||||
|
|||||||
430
src/main.rs
430
src/main.rs
@@ -1,13 +1,15 @@
|
|||||||
//! Telemt - MTProxy on Rust
|
//! telemt — Telegram MTProto Proxy
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tracing::{info, error, warn};
|
use tokio::sync::Semaphore;
|
||||||
use tracing_subscriber::{fmt, EnvFilter};
|
use tracing::{debug, error, info, warn};
|
||||||
|
use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};
|
||||||
|
|
||||||
|
mod cli;
|
||||||
mod config;
|
mod config;
|
||||||
mod crypto;
|
mod crypto;
|
||||||
mod error;
|
mod error;
|
||||||
@@ -18,78 +20,356 @@ mod stream;
|
|||||||
mod transport;
|
mod transport;
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::{LogLevel, ProxyConfig};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
use crate::proxy::ClientHandler;
|
use crate::proxy::ClientHandler;
|
||||||
use crate::stats::{Stats, ReplayChecker};
|
use crate::stats::{ReplayChecker, Stats};
|
||||||
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
|
||||||
use crate::util::ip::detect_ip;
|
|
||||||
use crate::stream::BufferPool;
|
use crate::stream::BufferPool;
|
||||||
|
use crate::transport::middle_proxy::MePool;
|
||||||
|
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
|
||||||
|
use crate::util::ip::detect_ip;
|
||||||
|
|
||||||
|
fn parse_cli() -> (String, bool, Option<String>) {
|
||||||
|
let mut config_path = "config.toml".to_string();
|
||||||
|
let mut silent = false;
|
||||||
|
let mut log_level: Option<String> = None;
|
||||||
|
|
||||||
|
let args: Vec<String> = std::env::args().skip(1).collect();
|
||||||
|
|
||||||
|
// Check for --init first (handled before tokio)
|
||||||
|
if let Some(init_opts) = cli::parse_init_args(&args) {
|
||||||
|
if let Err(e) = cli::run_init(init_opts) {
|
||||||
|
eprintln!("[telemt] Init failed: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--silent" | "-s" => {
|
||||||
|
silent = true;
|
||||||
|
}
|
||||||
|
"--log-level" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
log_level = Some(args[i].clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s if s.starts_with("--log-level=") => {
|
||||||
|
log_level = Some(s.trim_start_matches("--log-level=").to_string());
|
||||||
|
}
|
||||||
|
"--help" | "-h" => {
|
||||||
|
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Options:");
|
||||||
|
eprintln!(" --silent, -s Suppress info logs");
|
||||||
|
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
|
||||||
|
eprintln!(" --help, -h Show this help");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Setup (fire-and-forget):");
|
||||||
|
eprintln!(
|
||||||
|
" --init Generate config, install systemd service, start"
|
||||||
|
);
|
||||||
|
eprintln!(" --port <PORT> Listen port (default: 443)");
|
||||||
|
eprintln!(
|
||||||
|
" --domain <DOMAIN> TLS domain for masking (default: www.google.com)"
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" --secret <HEX> 32-char hex secret (auto-generated if omitted)"
|
||||||
|
);
|
||||||
|
eprintln!(" --user <NAME> Username (default: user)");
|
||||||
|
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
|
||||||
|
eprintln!(" --no-start Don't start the service after install");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
s if !s.starts_with('-') => {
|
||||||
|
config_path = s.to_string();
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
eprintln!("Unknown option: {}", other);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
(config_path, silent, log_level)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
// Initialize logging
|
let (config_path, cli_silent, cli_log_level) = parse_cli();
|
||||||
fmt()
|
|
||||||
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
|
|
||||||
.init();
|
|
||||||
|
|
||||||
// Load config
|
|
||||||
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
|
|
||||||
let config = match ProxyConfig::load(&config_path) {
|
let config = match ProxyConfig::load(&config_path) {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// If config doesn't exist, try to create default
|
|
||||||
if std::path::Path::new(&config_path).exists() {
|
if std::path::Path::new(&config_path).exists() {
|
||||||
error!("Failed to load config: {}", e);
|
eprintln!("[telemt] Error: {}", e);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
} else {
|
} else {
|
||||||
let default = ProxyConfig::default();
|
let default = ProxyConfig::default();
|
||||||
let toml = toml::to_string_pretty(&default).unwrap();
|
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||||
std::fs::write(&config_path, toml).unwrap();
|
eprintln!("[telemt] Created default config at {}", config_path);
|
||||||
info!("Created default config at {}", config_path);
|
|
||||||
default
|
default
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
config.validate()?;
|
if let Err(e) = config.validate() {
|
||||||
|
eprintln!("[telemt] Invalid config: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
// Log loaded configuration for debugging
|
let has_rust_log = std::env::var("RUST_LOG").is_ok();
|
||||||
info!("=== Configuration Loaded ===");
|
let effective_log_level = if cli_silent {
|
||||||
info!("TLS Domain: {}", config.censorship.tls_domain);
|
LogLevel::Silent
|
||||||
info!("Mask enabled: {}", config.censorship.mask);
|
} else if let Some(ref s) = cli_log_level {
|
||||||
info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain));
|
LogLevel::from_str_loose(s)
|
||||||
info!("Mask port: {}", config.censorship.mask_port);
|
} else {
|
||||||
info!("Modes: classic={}, secure={}, tls={}",
|
config.general.log_level.clone()
|
||||||
config.general.modes.classic,
|
};
|
||||||
config.general.modes.secure,
|
|
||||||
config.general.modes.tls
|
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(filter_layer)
|
||||||
|
.with(fmt::Layer::default())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
|
||||||
|
info!("Log level: {}", effective_log_level);
|
||||||
|
info!(
|
||||||
|
"Modes: classic={} secure={} tls={}",
|
||||||
|
config.general.modes.classic, config.general.modes.secure, config.general.modes.tls
|
||||||
);
|
);
|
||||||
info!("============================");
|
info!("TLS domain: {}", config.censorship.tls_domain);
|
||||||
|
if let Some(ref sock) = config.censorship.mask_unix_sock {
|
||||||
|
info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
|
||||||
|
if !std::path::Path::new(sock).exists() {
|
||||||
|
warn!(
|
||||||
|
"Unix socket '{}' does not exist yet. Masking will fail until it appears.",
|
||||||
|
sock
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"Mask: {} -> {}:{}",
|
||||||
|
config.censorship.mask,
|
||||||
|
config
|
||||||
|
.censorship
|
||||||
|
.mask_host
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&config.censorship.tls_domain),
|
||||||
|
config.censorship.mask_port
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.censorship.tls_domain == "www.google.com" {
|
||||||
|
warn!("Using default tls_domain. Consider setting a custom domain.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let prefer_ipv6 = config.general.prefer_ipv6;
|
||||||
|
let use_middle_proxy = config.general.use_middle_proxy;
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
// Initialize global ReplayChecker
|
let replay_checker = Arc::new(ReplayChecker::new(
|
||||||
// Using sharded implementation for better concurrency
|
config.access.replay_check_len,
|
||||||
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
|
Duration::from_secs(config.access.replay_window_secs),
|
||||||
|
));
|
||||||
|
|
||||||
// Initialize Upstream Manager
|
|
||||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||||
|
|
||||||
// Initialize Buffer Pool
|
|
||||||
// 16KB buffers, max 4096 buffers (~64MB total cached)
|
|
||||||
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
|
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
|
||||||
|
|
||||||
// Start Health Checks
|
// Connection concurrency limit
|
||||||
let um_clone = upstream_manager.clone();
|
let _max_connections = Arc::new(Semaphore::new(10_000));
|
||||||
tokio::spawn(async move {
|
|
||||||
um_clone.run_health_checks().await;
|
// =====================================================================
|
||||||
|
// Middle Proxy initialization (if enabled)
|
||||||
|
// =====================================================================
|
||||||
|
let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
|
||||||
|
info!("=== Middle Proxy Mode ===");
|
||||||
|
|
||||||
|
// ad_tag (proxy_tag) for advertising
|
||||||
|
let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
|
||||||
|
hex::decode(tag).unwrap_or_else(|_| {
|
||||||
|
warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
|
||||||
|
Vec::new()
|
||||||
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
// Detect public IP if needed (once at startup)
|
// =============================================================
|
||||||
let detected_ip = detect_ip().await;
|
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
|
||||||
|
//
|
||||||
|
// C MTProxy uses TWO separate secrets:
|
||||||
|
// -S flag = 16-byte user secret for client obfuscation
|
||||||
|
// --aes-pwd = 32-512 byte binary file for ME RPC auth
|
||||||
|
//
|
||||||
|
// proxy-secret is from: https://core.telegram.org/getProxySecret
|
||||||
|
// =============================================================
|
||||||
|
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
|
||||||
|
match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
|
||||||
|
Ok(proxy_secret) => {
|
||||||
|
info!(
|
||||||
|
secret_len = proxy_secret.len(),
|
||||||
|
key_sig = format_args!(
|
||||||
|
"0x{:08x}",
|
||||||
|
if proxy_secret.len() >= 4 {
|
||||||
|
u32::from_le_bytes([
|
||||||
|
proxy_secret[0],
|
||||||
|
proxy_secret[1],
|
||||||
|
proxy_secret[2],
|
||||||
|
proxy_secret[3],
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"Proxy-secret loaded"
|
||||||
|
);
|
||||||
|
|
||||||
|
let pool = MePool::new(
|
||||||
|
proxy_tag,
|
||||||
|
proxy_secret,
|
||||||
|
config.general.middle_proxy_nat_ip,
|
||||||
|
config.general.middle_proxy_nat_probe,
|
||||||
|
config.general.middle_proxy_nat_stun.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
match pool.init(2, &rng).await {
|
||||||
|
Ok(()) => {
|
||||||
|
info!("Middle-End pool initialized successfully");
|
||||||
|
|
||||||
|
// Phase 4: Start health monitor
|
||||||
|
let pool_clone = pool.clone();
|
||||||
|
let rng_clone = rng.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
crate::transport::middle_proxy::me_health_monitor(
|
||||||
|
pool_clone, rng_clone, 2,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Some(pool)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if me_pool.is_some() {
|
||||||
|
info!("Transport: Middle Proxy (supports all DCs including CDN)");
|
||||||
|
} else {
|
||||||
|
info!("Transport: Direct TCP (standard DCs only)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Startup DC ping (only meaningful in direct mode)
|
||||||
|
if me_pool.is_none() {
|
||||||
|
info!("================= Telegram DC Connectivity =================");
|
||||||
|
|
||||||
|
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
|
||||||
|
|
||||||
|
for upstream_result in &ping_results {
|
||||||
|
// Show which IP version is in use and which is fallback
|
||||||
|
if upstream_result.both_available {
|
||||||
|
if prefer_ipv6 {
|
||||||
|
info!(" IPv6 in use and IPv4 is fallback");
|
||||||
|
} else {
|
||||||
|
info!(" IPv4 in use and IPv6 is fallback");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let v6_works = upstream_result
|
||||||
|
.v6_results
|
||||||
|
.iter()
|
||||||
|
.any(|r| r.rtt_ms.is_some());
|
||||||
|
let v4_works = upstream_result
|
||||||
|
.v4_results
|
||||||
|
.iter()
|
||||||
|
.any(|r| r.rtt_ms.is_some());
|
||||||
|
if v6_works && !v4_works {
|
||||||
|
info!(" IPv6 only (IPv4 unavailable)");
|
||||||
|
} else if v4_works && !v6_works {
|
||||||
|
info!(" IPv4 only (IPv6 unavailable)");
|
||||||
|
} else if !v6_works && !v4_works {
|
||||||
|
info!(" No connectivity!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(" via {}", upstream_result.upstream_name);
|
||||||
|
info!("============================================================");
|
||||||
|
|
||||||
|
// Print IPv6 results first
|
||||||
|
for dc in &upstream_result.v6_results {
|
||||||
|
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
|
||||||
|
match &dc.rtt_ms {
|
||||||
|
Some(rtt) => {
|
||||||
|
// Align: IPv6 addresses are longer, use fewer tabs
|
||||||
|
// [2001:b28:f23d:f001::a]:443 = ~28 chars
|
||||||
|
info!(" DC{} [IPv6] {}:\t\t{:.0} ms", dc.dc_idx, addr_str, rtt);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let err = dc.error.as_deref().unwrap_or("fail");
|
||||||
|
info!(" DC{} [IPv6] {}:\t\tFAIL ({})", dc.dc_idx, addr_str, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("============================================================");
|
||||||
|
|
||||||
|
// Print IPv4 results
|
||||||
|
for dc in &upstream_result.v4_results {
|
||||||
|
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
|
||||||
|
match &dc.rtt_ms {
|
||||||
|
Some(rtt) => {
|
||||||
|
// Align: IPv4 addresses are shorter, use more tabs
|
||||||
|
// 149.154.175.50:443 = ~18 chars
|
||||||
|
info!(
|
||||||
|
" DC{} [IPv4] {}:\t\t\t\t{:.0} ms",
|
||||||
|
dc.dc_idx, addr_str, rtt
|
||||||
|
);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let err = dc.error.as_deref().unwrap_or("fail");
|
||||||
|
info!(
|
||||||
|
" DC{} [IPv4] {}:\t\t\t\tFAIL ({})",
|
||||||
|
dc.dc_idx, addr_str, err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("============================================================");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Background tasks
|
||||||
|
let um_clone = upstream_manager.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
um_clone.run_health_checks(prefer_ipv6).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let rc_clone = replay_checker.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
rc_clone.run_periodic_cleanup().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let detected_ip = detect_ip().await;
|
||||||
|
debug!(
|
||||||
|
"Detected IPs: v4={:?} v6={:?}",
|
||||||
|
detected_ip.ipv4, detected_ip.ipv6
|
||||||
|
);
|
||||||
|
|
||||||
// Start Listeners
|
|
||||||
let mut listeners = Vec::new();
|
let mut listeners = Vec::new();
|
||||||
|
|
||||||
for listener_conf in &config.server.listeners {
|
for listener_conf in &config.server.listeners {
|
||||||
@@ -104,7 +384,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let listener = TcpListener::from_std(socket.into())?;
|
let listener = TcpListener::from_std(socket.into())?;
|
||||||
info!("Listening on {}", addr);
|
info!("Listening on {}", addr);
|
||||||
|
|
||||||
// Determine public IP for tg:// links
|
|
||||||
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
||||||
ip
|
ip
|
||||||
} else if listener_conf.ip.is_unspecified() {
|
} else if listener_conf.ip.is_unspecified() {
|
||||||
@@ -117,37 +396,39 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
listener_conf.ip
|
listener_conf.ip
|
||||||
};
|
};
|
||||||
|
|
||||||
// Show links for configured users
|
|
||||||
if !config.show_link.is_empty() {
|
if !config.show_link.is_empty() {
|
||||||
info!("--- Proxy Links for {} ---", public_ip);
|
info!("--- Proxy Links ({}) ---", public_ip);
|
||||||
for user_name in &config.show_link {
|
for user_name in &config.show_link {
|
||||||
if let Some(secret) = config.access.users.get(user_name) {
|
if let Some(secret) = config.access.users.get(user_name) {
|
||||||
info!("User: {}", user_name);
|
info!("User: {}", user_name);
|
||||||
|
|
||||||
if config.general.modes.classic {
|
if config.general.modes.classic {
|
||||||
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
|
info!(
|
||||||
public_ip, config.server.port, secret);
|
" Classic: tg://proxy?server={}&port={}&secret={}",
|
||||||
|
public_ip, config.server.port, secret
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.general.modes.secure {
|
if config.general.modes.secure {
|
||||||
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
info!(
|
||||||
public_ip, config.server.port, secret);
|
" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
||||||
|
public_ip, config.server.port, secret
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.general.modes.tls {
|
if config.general.modes.tls {
|
||||||
let domain_hex = hex::encode(&config.censorship.tls_domain);
|
let domain_hex = hex::encode(&config.censorship.tls_domain);
|
||||||
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
info!(
|
||||||
public_ip, config.server.port, secret, domain_hex);
|
" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||||
|
public_ip, config.server.port, secret, domain_hex
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
warn!("User '{}' specified in show_link not found in users list", user_name);
|
warn!("User '{}' in show_link not found", user_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info!("-----------------------------------");
|
info!("------------------------");
|
||||||
}
|
}
|
||||||
|
|
||||||
listeners.push(listener);
|
listeners.push(listener);
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to bind to {}: {}", addr, e);
|
error!("Failed to bind to {}: {}", addr, e);
|
||||||
}
|
}
|
||||||
@@ -155,17 +436,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if listeners.is_empty() {
|
if listeners.is_empty() {
|
||||||
error!("No listeners could be started. Exiting.");
|
error!("No listeners. Exiting.");
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept loop
|
// Switch to user-configured log level after startup
|
||||||
|
let runtime_filter = if has_rust_log {
|
||||||
|
EnvFilter::from_default_env()
|
||||||
|
} else {
|
||||||
|
EnvFilter::new(effective_log_level.to_filter_str())
|
||||||
|
};
|
||||||
|
filter_handle
|
||||||
|
.reload(runtime_filter)
|
||||||
|
.expect("Failed to switch log filter");
|
||||||
|
|
||||||
for listener in listeners {
|
for listener in listeners {
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
let stats = stats.clone();
|
let stats = stats.clone();
|
||||||
let upstream_manager = upstream_manager.clone();
|
let upstream_manager = upstream_manager.clone();
|
||||||
let replay_checker = replay_checker.clone();
|
let replay_checker = replay_checker.clone();
|
||||||
let buffer_pool = buffer_pool.clone();
|
let buffer_pool = buffer_pool.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
let me_pool = me_pool.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
@@ -176,6 +468,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let upstream_manager = upstream_manager.clone();
|
let upstream_manager = upstream_manager.clone();
|
||||||
let replay_checker = replay_checker.clone();
|
let replay_checker = replay_checker.clone();
|
||||||
let buffer_pool = buffer_pool.clone();
|
let buffer_pool = buffer_pool.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
let me_pool = me_pool.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = ClientHandler::new(
|
if let Err(e) = ClientHandler::new(
|
||||||
@@ -185,9 +479,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
stats,
|
stats,
|
||||||
upstream_manager,
|
upstream_manager,
|
||||||
replay_checker,
|
replay_checker,
|
||||||
buffer_pool
|
buffer_pool,
|
||||||
).run().await {
|
rng,
|
||||||
// Log only relevant errors
|
me_pool,
|
||||||
|
)
|
||||||
|
.run()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
debug!(peer = %peer_addr, error = %e, "Connection error");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -200,7 +499,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for signal
|
|
||||||
match signal::ctrl_c().await {
|
match signal::ctrl_c().await {
|
||||||
Ok(()) => info!("Shutting down..."),
|
Ok(()) => info!("Shutting down..."),
|
||||||
Err(e) => error!("Signal error: {}", e),
|
Err(e) => error!("Signal error: {}", e),
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
//! Protocol constants and datacenter addresses
|
//! Protocol constants and datacenter addresses
|
||||||
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
use once_cell::sync::Lazy;
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
// ============= Telegram Datacenters =============
|
// ============= Telegram Datacenters =============
|
||||||
|
|
||||||
pub const TG_DATACENTER_PORT: u16 = 443;
|
pub const TG_DATACENTER_PORT: u16 = 443;
|
||||||
|
|
||||||
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
|
||||||
vec![
|
vec![
|
||||||
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
|
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
|
||||||
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
|
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
|
||||||
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
|
||||||
vec![
|
vec![
|
||||||
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
|
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
|
||||||
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
|
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
|
||||||
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
|||||||
|
|
||||||
// ============= Middle Proxies (for advertising) =============
|
// ============= Middle Proxies (for advertising) =============
|
||||||
|
|
||||||
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||||
Lazy::new(|| {
|
LazyLock::new(|| {
|
||||||
let mut m = std::collections::HashMap::new();
|
let mut m = std::collections::HashMap::new();
|
||||||
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||||
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||||
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
|
|||||||
m
|
m
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||||
Lazy::new(|| {
|
LazyLock::new(|| {
|
||||||
let mut m = std::collections::HashMap::new();
|
let mut m = std::collections::HashMap::new();
|
||||||
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||||
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||||
@@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
|||||||
// ============= Buffer Sizes =============
|
// ============= Buffer Sizes =============
|
||||||
|
|
||||||
/// Default buffer size
|
/// Default buffer size
|
||||||
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
|
|
||||||
/// the new buffering strategy for better iOS upload performance.
|
|
||||||
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
||||||
|
|
||||||
/// Small buffer size for bad client handling
|
/// Small buffer size for bad client handling
|
||||||
@@ -204,6 +202,17 @@ pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
|
|||||||
// ============= RPC Constants (for Middle Proxy) =============
|
// ============= RPC Constants (for Middle Proxy) =============
|
||||||
|
|
||||||
/// RPC Proxy Request
|
/// RPC Proxy Request
|
||||||
|
|
||||||
|
/// RPC Flags (from Erlang mtp_rpc.erl)
|
||||||
|
pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2;
|
||||||
|
pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8;
|
||||||
|
pub const RPC_FLAG_MAGIC: u32 = 0x1000;
|
||||||
|
pub const RPC_FLAG_EXTMODE2: u32 = 0x20000;
|
||||||
|
pub const RPC_FLAG_PAD: u32 = 0x8000000;
|
||||||
|
pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000;
|
||||||
|
pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000;
|
||||||
|
pub const RPC_FLAG_QUICKACK: u32 = 0x80000000;
|
||||||
|
|
||||||
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
|
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
|
||||||
/// RPC Proxy Answer
|
/// RPC Proxy Answer
|
||||||
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
|
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
|
||||||
@@ -230,7 +239,56 @@ pub mod rpc_flags {
|
|||||||
pub const FLAG_QUICKACK: u32 = 0x80000000;
|
pub const FLAG_QUICKACK: u32 = 0x80000000;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
|
// ============= Middle-End Proxy Servers =============
|
||||||
|
pub const ME_PROXY_PORT: u16 = 8888;
|
||||||
|
|
||||||
|
pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
|
||||||
|
vec![
|
||||||
|
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
|
||||||
|
(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888),
|
||||||
|
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888),
|
||||||
|
(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888),
|
||||||
|
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888),
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
// ============= RPC Constants (u32 native endian) =============
|
||||||
|
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c
|
||||||
|
|
||||||
|
pub const RPC_NONCE_U32: u32 = 0x7acb87aa;
|
||||||
|
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5;
|
||||||
|
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda;
|
||||||
|
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
|
||||||
|
|
||||||
|
// mtproto-common.h
|
||||||
|
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
|
||||||
|
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
|
||||||
|
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
|
||||||
|
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
|
||||||
|
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
|
||||||
|
pub const RPC_PING_U32: u32 = 0x5730a2df;
|
||||||
|
pub const RPC_PONG_U32: u32 = 0x8430eaa7;
|
||||||
|
|
||||||
|
pub const RPC_CRYPTO_NONE_U32: u32 = 0;
|
||||||
|
pub const RPC_CRYPTO_AES_U32: u32 = 1;
|
||||||
|
|
||||||
|
pub mod proxy_flags {
|
||||||
|
pub const FLAG_HAS_AD_TAG: u32 = 1;
|
||||||
|
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
|
||||||
|
pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
|
||||||
|
pub const FLAG_MAGIC: u32 = 0x1000;
|
||||||
|
pub const FLAG_EXTMODE2: u32 = 0x20000;
|
||||||
|
pub const FLAG_PAD: u32 = 0x8000000;
|
||||||
|
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
|
||||||
|
pub const FLAG_ABRIDGED: u32 = 0x40000000;
|
||||||
|
pub const FLAG_QUICKACK: u32 = 0x80000000;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
|
||||||
|
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
//! MTProto Obfuscation
|
//! MTProto Obfuscation
|
||||||
|
|
||||||
|
use zeroize::Zeroize;
|
||||||
use crate::crypto::{sha256, AesCtr};
|
use crate::crypto::{sha256, AesCtr};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use super::constants::*;
|
use super::constants::*;
|
||||||
|
|
||||||
/// Obfuscation parameters from handshake
|
/// Obfuscation parameters from handshake
|
||||||
|
///
|
||||||
|
/// Key material is zeroized on drop.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ObfuscationParams {
|
pub struct ObfuscationParams {
|
||||||
/// Key for decrypting client -> proxy traffic
|
/// Key for decrypting client -> proxy traffic
|
||||||
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
|
|||||||
pub dc_idx: i16,
|
pub dc_idx: i16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for ObfuscationParams {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.decrypt_key.zeroize();
|
||||||
|
self.decrypt_iv.zeroize();
|
||||||
|
self.encrypt_key.zeroize();
|
||||||
|
self.encrypt_iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ObfuscationParams {
|
impl ObfuscationParams {
|
||||||
/// Parse obfuscation parameters from handshake bytes
|
/// Parse obfuscation parameters from handshake bytes
|
||||||
/// Returns None if handshake doesn't match any user secret
|
/// Returns None if handshake doesn't match any user secret
|
||||||
pub fn from_handshake(
|
pub fn from_handshake(
|
||||||
handshake: &[u8; HANDSHAKE_LEN],
|
handshake: &[u8; HANDSHAKE_LEN],
|
||||||
secrets: &[(String, Vec<u8>)], // (username, secret_bytes)
|
secrets: &[(String, Vec<u8>)],
|
||||||
) -> Option<(Self, String)> {
|
) -> Option<(Self, String)> {
|
||||||
// Extract prekey and IV for decryption
|
|
||||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
// Reversed for encryption direction
|
|
||||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
for (username, secret) in secrets {
|
for (username, secret) in secrets {
|
||||||
// Derive decryption key
|
|
||||||
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||||
dec_key_input.extend_from_slice(dec_prekey);
|
dec_key_input.extend_from_slice(dec_prekey);
|
||||||
dec_key_input.extend_from_slice(secret);
|
dec_key_input.extend_from_slice(secret);
|
||||||
@@ -47,26 +56,22 @@ impl ObfuscationParams {
|
|||||||
|
|
||||||
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Create decryptor and decrypt handshake
|
|
||||||
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
|
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
|
||||||
let decrypted = decryptor.decrypt(handshake);
|
let decrypted = decryptor.decrypt(handshake);
|
||||||
|
|
||||||
// Check protocol tag
|
|
||||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||||
.try_into()
|
.try_into()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||||
Some(tag) => tag,
|
Some(tag) => tag,
|
||||||
None => continue, // Try next secret
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Extract DC index
|
|
||||||
let dc_idx = i16::from_le_bytes(
|
let dc_idx = i16::from_le_bytes(
|
||||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Derive encryption key
|
|
||||||
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||||
enc_key_input.extend_from_slice(enc_prekey);
|
enc_key_input.extend_from_slice(enc_prekey);
|
||||||
enc_key_input.extend_from_slice(secret);
|
enc_key_input.extend_from_slice(secret);
|
||||||
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
|
|||||||
|
|
||||||
/// Check if nonce is valid (not matching reserved patterns)
|
/// Check if nonce is valid (not matching reserved patterns)
|
||||||
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
||||||
// Check first byte
|
|
||||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check first 4 bytes
|
|
||||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
||||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check bytes 4-7
|
|
||||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
||||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
||||||
return false;
|
return false;
|
||||||
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
|||||||
pub fn prepare_tg_nonce(
|
pub fn prepare_tg_nonce(
|
||||||
nonce: &mut [u8; HANDSHAKE_LEN],
|
nonce: &mut [u8; HANDSHAKE_LEN],
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
enc_key_iv: Option<&[u8]>, // For fast mode
|
enc_key_iv: Option<&[u8]>,
|
||||||
) {
|
) {
|
||||||
// Set protocol tag
|
|
||||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||||
|
|
||||||
// For fast mode, copy the reversed enc_key_iv
|
|
||||||
if let Some(key_iv) = enc_key_iv {
|
if let Some(key_iv) = enc_key_iv {
|
||||||
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
|
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
|
||||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
|
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
|
||||||
@@ -161,14 +161,12 @@ pub fn prepare_tg_nonce(
|
|||||||
|
|
||||||
/// Encrypt the outgoing nonce for Telegram
|
/// Encrypt the outgoing nonce for Telegram
|
||||||
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||||
// Derive encryption key from the nonce itself
|
|
||||||
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let enc_key = sha256(key_iv);
|
let enc_key = sha256(key_iv);
|
||||||
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
|
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
|
||||||
|
|
||||||
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
|
||||||
// Only encrypt from PROTO_TAG_POS onwards
|
|
||||||
let mut result = nonce.to_vec();
|
let mut result = nonce.to_vec();
|
||||||
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
|
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
|
||||||
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
|
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
|
||||||
@@ -182,22 +180,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_is_valid_nonce() {
|
fn test_is_valid_nonce() {
|
||||||
// Valid nonce
|
|
||||||
let mut valid = [0x42u8; HANDSHAKE_LEN];
|
let mut valid = [0x42u8; HANDSHAKE_LEN];
|
||||||
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
|
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
|
||||||
assert!(is_valid_nonce(&valid));
|
assert!(is_valid_nonce(&valid));
|
||||||
|
|
||||||
// Invalid: starts with 0xef
|
|
||||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||||
invalid[0] = 0xef;
|
invalid[0] = 0xef;
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|
||||||
// Invalid: starts with HEAD
|
|
||||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||||
invalid[..4].copy_from_slice(b"HEAD");
|
invalid[..4].copy_from_slice(b"HEAD");
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|
||||||
// Invalid: bytes 4-7 are zeros
|
|
||||||
let mut invalid = [0x42u8; HANDSHAKE_LEN];
|
let mut invalid = [0x42u8; HANDSHAKE_LEN];
|
||||||
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
|
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
//! for domain fronting. The handshake looks like valid TLS 1.3 but
|
//! for domain fronting. The handshake looks like valid TLS 1.3 but
|
||||||
//! actually carries MTProto authentication data.
|
//! actually carries MTProto authentication data.
|
||||||
|
|
||||||
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
|
use crate::crypto::{sha256_hmac, SecureRandom};
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
use super::constants::*;
|
use super::constants::*;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
@@ -315,8 +315,8 @@ pub fn validate_tls_handshake(
|
|||||||
///
|
///
|
||||||
/// This generates random bytes that look like a valid X25519 public key.
|
/// This generates random bytes that look like a valid X25519 public key.
|
||||||
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
|
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
|
||||||
pub fn gen_fake_x25519_key() -> [u8; 32] {
|
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
||||||
let bytes = SECURE_RANDOM.bytes(32);
|
let bytes = rng.bytes(32);
|
||||||
bytes.try_into().unwrap()
|
bytes.try_into().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,8 +333,9 @@ pub fn build_server_hello(
|
|||||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||||
session_id: &[u8],
|
session_id: &[u8],
|
||||||
fake_cert_len: usize,
|
fake_cert_len: usize,
|
||||||
|
rng: &SecureRandom,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
let x25519_key = gen_fake_x25519_key();
|
let x25519_key = gen_fake_x25519_key(rng);
|
||||||
|
|
||||||
// Build ServerHello
|
// Build ServerHello
|
||||||
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||||
@@ -351,7 +352,7 @@ pub fn build_server_hello(
|
|||||||
];
|
];
|
||||||
|
|
||||||
// Build fake certificate (Application Data record)
|
// Build fake certificate (Application Data record)
|
||||||
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
|
let fake_cert = rng.bytes(fake_cert_len);
|
||||||
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||||
app_data_record.push(TLS_RECORD_APPLICATION);
|
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||||
app_data_record.extend_from_slice(&TLS_VERSION);
|
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||||
@@ -489,8 +490,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_gen_fake_x25519_key() {
|
fn test_gen_fake_x25519_key() {
|
||||||
let key1 = gen_fake_x25519_key();
|
let rng = SecureRandom::new();
|
||||||
let key2 = gen_fake_x25519_key();
|
let key1 = gen_fake_x25519_key(&rng);
|
||||||
|
let key2 = gen_fake_x25519_key(&rng);
|
||||||
|
|
||||||
assert_eq!(key1.len(), 32);
|
assert_eq!(key1.len(), 32);
|
||||||
assert_eq!(key2.len(), 32);
|
assert_eq!(key2.len(), 32);
|
||||||
@@ -545,7 +547,8 @@ mod tests {
|
|||||||
let client_digest = [0x42u8; 32];
|
let client_digest = [0x42u8; 32];
|
||||||
let session_id = vec![0xAA; 32];
|
let session_id = vec![0xAA; 32];
|
||||||
|
|
||||||
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
|
let rng = SecureRandom::new();
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
|
||||||
|
|
||||||
// Should have at least 3 records
|
// Should have at least 3 records
|
||||||
assert!(response.len() > 100);
|
assert!(response.len() > 100);
|
||||||
@@ -577,8 +580,9 @@ mod tests {
|
|||||||
let client_digest = [0x42u8; 32];
|
let client_digest = [0x42u8; 32];
|
||||||
let session_id = vec![0xAA; 32];
|
let session_id = vec![0xAA; 32];
|
||||||
|
|
||||||
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
let rng = SecureRandom::new();
|
||||||
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
|
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||||
|
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||||
|
|
||||||
// Digest position should have non-zero data
|
// Digest position should have non-zero data
|
||||||
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||||
|
|||||||
@@ -3,32 +3,28 @@
|
|||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::{debug, info, warn, error, trace};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::error::{ProxyError, Result, HandshakeResult};
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::error::{HandshakeResult, ProxyError, Result};
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
use crate::stats::{Stats, ReplayChecker};
|
use crate::stats::{ReplayChecker, Stats};
|
||||||
use crate::transport::{configure_client_socket, UpstreamManager};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
|
use crate::transport::middle_proxy::MePool;
|
||||||
use crate::crypto::AesCtr;
|
use crate::transport::{UpstreamManager, configure_client_socket};
|
||||||
|
|
||||||
// Use absolute paths to avoid confusion
|
use crate::proxy::direct_relay::handle_via_direct;
|
||||||
use crate::proxy::handshake::{
|
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
|
||||||
handle_tls_handshake, handle_mtproto_handshake,
|
|
||||||
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
|
|
||||||
};
|
|
||||||
use crate::proxy::relay::relay_bidirectional;
|
|
||||||
use crate::proxy::masking::handle_bad_client;
|
use crate::proxy::masking::handle_bad_client;
|
||||||
|
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||||
|
|
||||||
/// Client connection handler (builder struct)
|
|
||||||
pub struct ClientHandler;
|
pub struct ClientHandler;
|
||||||
|
|
||||||
/// Running client handler with stream and context
|
|
||||||
pub struct RunningClientHandler {
|
pub struct RunningClientHandler {
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
@@ -37,10 +33,11 @@ pub struct RunningClientHandler {
|
|||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
me_pool: Option<Arc<MePool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientHandler {
|
impl ClientHandler {
|
||||||
/// Create new client handler instance
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
@@ -49,6 +46,8 @@ impl ClientHandler {
|
|||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
me_pool: Option<Arc<MePool>>,
|
||||||
) -> RunningClientHandler {
|
) -> RunningClientHandler {
|
||||||
RunningClientHandler {
|
RunningClientHandler {
|
||||||
stream,
|
stream,
|
||||||
@@ -58,19 +57,19 @@ impl ClientHandler {
|
|||||||
replay_checker,
|
replay_checker,
|
||||||
upstream_manager,
|
upstream_manager,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
me_pool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RunningClientHandler {
|
impl RunningClientHandler {
|
||||||
/// Run the client handler
|
|
||||||
pub async fn run(mut self) -> Result<()> {
|
pub async fn run(mut self) -> Result<()> {
|
||||||
self.stats.increment_connects_all();
|
self.stats.increment_connects_all();
|
||||||
|
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
debug!(peer = %peer, "New connection");
|
debug!(peer = %peer, "New connection");
|
||||||
|
|
||||||
// Configure socket
|
|
||||||
if let Err(e) = configure_client_socket(
|
if let Err(e) = configure_client_socket(
|
||||||
&self.stream,
|
&self.stream,
|
||||||
self.config.timeouts.client_keepalive,
|
self.config.timeouts.client_keepalive,
|
||||||
@@ -79,16 +78,10 @@ impl RunningClientHandler {
|
|||||||
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform handshake with timeout
|
|
||||||
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
||||||
|
|
||||||
// Clone stats for error handling block
|
|
||||||
let stats = self.stats.clone();
|
let stats = self.stats.clone();
|
||||||
|
|
||||||
let result = timeout(
|
let result = timeout(handshake_timeout, self.do_handshake()).await;
|
||||||
handshake_timeout,
|
|
||||||
self.do_handshake()
|
|
||||||
).await;
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => {
|
Ok(Ok(())) => {
|
||||||
@@ -107,16 +100,14 @@ impl RunningClientHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform handshake and relay
|
|
||||||
async fn do_handshake(mut self) -> Result<()> {
|
async fn do_handshake(mut self) -> Result<()> {
|
||||||
// Read first bytes to determine handshake type
|
|
||||||
let mut first_bytes = [0u8; 5];
|
let mut first_bytes = [0u8; 5];
|
||||||
self.stream.read_exact(&mut first_bytes).await?;
|
self.stream.read_exact(&mut first_bytes).await?;
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
|
||||||
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
|
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||||
|
|
||||||
if is_tls {
|
if is_tls {
|
||||||
self.handle_tls_client(first_bytes).await
|
self.handle_tls_client(first_bytes).await
|
||||||
@@ -125,14 +116,9 @@ impl RunningClientHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle TLS-wrapped client
|
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
||||||
async fn handle_tls_client(
|
|
||||||
mut self,
|
|
||||||
first_bytes: [u8; 5],
|
|
||||||
) -> Result<()> {
|
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
|
||||||
// Read TLS handshake length
|
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
||||||
@@ -140,27 +126,23 @@ impl RunningClientHandler {
|
|||||||
if tls_len < 512 {
|
if tls_len < 512 {
|
||||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
// FIX: Split stream into reader/writer for handle_bad_client
|
|
||||||
let (reader, writer) = self.stream.into_split();
|
let (reader, writer) = self.stream.into_split();
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read full TLS handshake
|
|
||||||
let mut handshake = vec![0u8; 5 + tls_len];
|
let mut handshake = vec![0u8; 5 + tls_len];
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
// Extract fields before consuming self.stream
|
|
||||||
let config = self.config.clone();
|
let config = self.config.clone();
|
||||||
let replay_checker = self.replay_checker.clone();
|
let replay_checker = self.replay_checker.clone();
|
||||||
let stats = self.stats.clone();
|
let stats = self.stats.clone();
|
||||||
let buffer_pool = self.buffer_pool.clone();
|
let buffer_pool = self.buffer_pool.clone();
|
||||||
|
|
||||||
// Split stream for reading/writing
|
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
|
||||||
let (read_half, write_half) = self.stream.into_split();
|
let (read_half, write_half) = self.stream.into_split();
|
||||||
|
|
||||||
// Handle TLS handshake
|
|
||||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||||
&handshake,
|
&handshake,
|
||||||
read_half,
|
read_half,
|
||||||
@@ -168,7 +150,10 @@ impl RunningClientHandler {
|
|||||||
peer,
|
peer,
|
||||||
&config,
|
&config,
|
||||||
&replay_checker,
|
&replay_checker,
|
||||||
).await {
|
&self.rng,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
@@ -178,13 +163,12 @@ impl RunningClientHandler {
|
|||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Read MTProto handshake through TLS
|
|
||||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||||
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
||||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
|
||||||
|
.try_into()
|
||||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||||
|
|
||||||
// Handle MTProto handshake
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
&mtproto_handshake,
|
&mtproto_handshake,
|
||||||
tls_reader,
|
tls_reader,
|
||||||
@@ -193,12 +177,16 @@ impl RunningClientHandler {
|
|||||||
&config,
|
&config,
|
||||||
&replay_checker,
|
&replay_checker,
|
||||||
true,
|
true,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
HandshakeResult::BadClient {
|
||||||
|
reader: _,
|
||||||
|
writer: _,
|
||||||
|
} => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
// Valid TLS but invalid MTProto - drop
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -211,42 +199,37 @@ impl RunningClientHandler {
|
|||||||
self.upstream_manager,
|
self.upstream_manager,
|
||||||
self.stats,
|
self.stats,
|
||||||
self.config,
|
self.config,
|
||||||
buffer_pool
|
buffer_pool,
|
||||||
).await
|
self.rng,
|
||||||
|
self.me_pool,
|
||||||
|
local_addr,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle direct (non-TLS) client
|
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
||||||
async fn handle_direct_client(
|
|
||||||
mut self,
|
|
||||||
first_bytes: [u8; 5],
|
|
||||||
) -> Result<()> {
|
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
|
||||||
// Check if non-TLS modes are enabled
|
|
||||||
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
// FIX: Split stream into reader/writer for handle_bad_client
|
|
||||||
let (reader, writer) = self.stream.into_split();
|
let (reader, writer) = self.stream.into_split();
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read rest of handshake
|
|
||||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
// Extract fields
|
|
||||||
let config = self.config.clone();
|
let config = self.config.clone();
|
||||||
let replay_checker = self.replay_checker.clone();
|
let replay_checker = self.replay_checker.clone();
|
||||||
let stats = self.stats.clone();
|
let stats = self.stats.clone();
|
||||||
let buffer_pool = self.buffer_pool.clone();
|
let buffer_pool = self.buffer_pool.clone();
|
||||||
|
|
||||||
// Split stream
|
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
|
||||||
let (read_half, write_half) = self.stream.into_split();
|
let (read_half, write_half) = self.stream.into_split();
|
||||||
|
|
||||||
// Handle MTProto handshake
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
&handshake,
|
&handshake,
|
||||||
read_half,
|
read_half,
|
||||||
@@ -255,7 +238,9 @@ impl RunningClientHandler {
|
|||||||
&config,
|
&config,
|
||||||
&replay_checker,
|
&replay_checker,
|
||||||
false,
|
false,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
@@ -272,11 +257,18 @@ impl RunningClientHandler {
|
|||||||
self.upstream_manager,
|
self.upstream_manager,
|
||||||
self.stats,
|
self.stats,
|
||||||
self.config,
|
self.config,
|
||||||
buffer_pool
|
buffer_pool,
|
||||||
).await
|
self.rng,
|
||||||
|
self.me_pool,
|
||||||
|
local_addr,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Static version of handle_authenticated_inner
|
/// Main dispatch after successful handshake.
|
||||||
|
/// Two modes:
|
||||||
|
/// - Direct: TCP relay to TG DC (existing behavior)
|
||||||
|
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
|
||||||
async fn handle_authenticated_static<R, W>(
|
async fn handle_authenticated_static<R, W>(
|
||||||
client_reader: CryptoReader<R>,
|
client_reader: CryptoReader<R>,
|
||||||
client_writer: CryptoWriter<W>,
|
client_writer: CryptoWriter<W>,
|
||||||
@@ -285,6 +277,9 @@ impl RunningClientHandler {
|
|||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
me_pool: Option<Arc<MePool>>,
|
||||||
|
local_addr: SocketAddr,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
@@ -292,149 +287,68 @@ impl RunningClientHandler {
|
|||||||
{
|
{
|
||||||
let user = &success.user;
|
let user = &success.user;
|
||||||
|
|
||||||
// Check user limits
|
|
||||||
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
|
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
|
||||||
warn!(user = %user, error = %e, "User limit exceeded");
|
warn!(user = %user, error = %e, "User limit exceeded");
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get datacenter address
|
// Decide: middle proxy or direct
|
||||||
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
|
if config.general.use_middle_proxy {
|
||||||
|
if let Some(ref pool) = me_pool {
|
||||||
info!(
|
return handle_via_middle_proxy(
|
||||||
user = %user,
|
|
||||||
peer = %success.peer,
|
|
||||||
dc = success.dc_idx,
|
|
||||||
dc_addr = %dc_addr,
|
|
||||||
proto = ?success.proto_tag,
|
|
||||||
fast_mode = config.general.fast_mode,
|
|
||||||
"Connecting to Telegram"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Connect to Telegram via UpstreamManager
|
|
||||||
let tg_stream = upstream_manager.connect(dc_addr).await?;
|
|
||||||
|
|
||||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
|
|
||||||
|
|
||||||
// Perform Telegram handshake and get crypto streams
|
|
||||||
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
|
|
||||||
tg_stream,
|
|
||||||
&success,
|
|
||||||
&config,
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
|
||||||
|
|
||||||
// Update stats
|
|
||||||
stats.increment_user_connects(user);
|
|
||||||
stats.increment_user_curr_connects(user);
|
|
||||||
|
|
||||||
// Relay traffic using buffer pool
|
|
||||||
let relay_result = relay_bidirectional(
|
|
||||||
client_reader,
|
client_reader,
|
||||||
client_writer,
|
client_writer,
|
||||||
tg_reader,
|
success,
|
||||||
tg_writer,
|
pool.clone(),
|
||||||
user,
|
stats,
|
||||||
Arc::clone(&stats),
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
).await;
|
local_addr,
|
||||||
|
)
|
||||||
// Update stats
|
.await;
|
||||||
stats.decrement_user_curr_connects(user);
|
}
|
||||||
|
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
|
||||||
match &relay_result {
|
|
||||||
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
|
|
||||||
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
relay_result
|
// Direct mode (original behavior)
|
||||||
|
handle_via_direct(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats,
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check user limits (static version)
|
|
||||||
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
|
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
|
||||||
// Check expiration
|
|
||||||
if let Some(expiration) = config.access.user_expirations.get(user) {
|
if let Some(expiration) = config.access.user_expirations.get(user) {
|
||||||
if chrono::Utc::now() > *expiration {
|
if chrono::Utc::now() > *expiration {
|
||||||
return Err(ProxyError::UserExpired { user: user.to_string() });
|
return Err(ProxyError::UserExpired {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check connection limit
|
|
||||||
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
|
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
|
||||||
let current = stats.get_user_curr_connects(user);
|
if stats.get_user_curr_connects(user) >= *limit as u64 {
|
||||||
if current >= *limit as u64 {
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
user: user.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check data quota
|
|
||||||
if let Some(quota) = config.access.user_data_quota.get(user) {
|
if let Some(quota) = config.access.user_data_quota.get(user) {
|
||||||
let used = stats.get_user_total_octets(user);
|
if stats.get_user_total_octets(user) >= *quota {
|
||||||
if used >= *quota {
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
user: user.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get datacenter address by index (static version)
|
|
||||||
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|
||||||
let idx = (dc_idx.abs() - 1) as usize;
|
|
||||||
|
|
||||||
let datacenters = if config.general.prefer_ipv6 {
|
|
||||||
&*TG_DATACENTERS_V6
|
|
||||||
} else {
|
|
||||||
&*TG_DATACENTERS_V4
|
|
||||||
};
|
|
||||||
|
|
||||||
datacenters.get(idx)
|
|
||||||
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
|
|
||||||
.ok_or_else(|| ProxyError::InvalidHandshake(
|
|
||||||
format!("Invalid DC index: {}", dc_idx)
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform handshake with Telegram server (static version)
|
|
||||||
async fn do_tg_handshake_static(
|
|
||||||
mut stream: TcpStream,
|
|
||||||
success: &HandshakeSuccess,
|
|
||||||
config: &ProxyConfig,
|
|
||||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
|
||||||
// Generate nonce with keys for TG
|
|
||||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
|
||||||
success.proto_tag,
|
|
||||||
&success.dec_key, // Client's dec key
|
|
||||||
success.dec_iv,
|
|
||||||
config.general.fast_mode,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Encrypt nonce
|
|
||||||
let encrypted_nonce = encrypt_tg_nonce(&nonce);
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
peer = %success.peer,
|
|
||||||
nonce_head = %hex::encode(&nonce[..16]),
|
|
||||||
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
|
|
||||||
"Sending nonce to Telegram"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Send to Telegram
|
|
||||||
stream.write_all(&encrypted_nonce).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
debug!(peer = %success.peer, "Nonce sent to Telegram");
|
|
||||||
|
|
||||||
// Split stream and wrap with crypto
|
|
||||||
let (read_half, write_half) = stream.into_split();
|
|
||||||
|
|
||||||
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
|
|
||||||
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
|
|
||||||
|
|
||||||
let tg_reader = CryptoReader::new(read_half, decryptor);
|
|
||||||
let tg_writer = CryptoWriter::new(write_half, encryptor);
|
|
||||||
|
|
||||||
Ok((tg_reader, tg_writer))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
163
src/proxy/direct_relay.rs
Normal file
163
src/proxy/direct_relay.rs
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::config::ProxyConfig;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::error::Result;
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
|
||||||
|
use crate::proxy::relay::relay_bidirectional;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
|
use crate::transport::UpstreamManager;
|
||||||
|
|
||||||
|
pub(crate) async fn handle_via_direct<R, W>(
|
||||||
|
client_reader: CryptoReader<R>,
|
||||||
|
client_writer: CryptoWriter<W>,
|
||||||
|
success: HandshakeSuccess,
|
||||||
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
config: Arc<ProxyConfig>,
|
||||||
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let user = &success.user;
|
||||||
|
let dc_addr = get_dc_addr_static(success.dc_idx, &config)?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
user = %user,
|
||||||
|
peer = %success.peer,
|
||||||
|
dc = success.dc_idx,
|
||||||
|
dc_addr = %dc_addr,
|
||||||
|
proto = ?success.proto_tag,
|
||||||
|
mode = "direct",
|
||||||
|
"Connecting to Telegram DC"
|
||||||
|
);
|
||||||
|
|
||||||
|
let tg_stream = upstream_manager
|
||||||
|
.connect(dc_addr, Some(success.dc_idx))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||||
|
|
||||||
|
let (tg_reader, tg_writer) =
|
||||||
|
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
|
||||||
|
|
||||||
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
|
stats.increment_user_connects(user);
|
||||||
|
stats.increment_user_curr_connects(user);
|
||||||
|
|
||||||
|
let relay_result = relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
tg_reader,
|
||||||
|
tg_writer,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
buffer_pool,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
stats.decrement_user_curr_connects(user);
|
||||||
|
|
||||||
|
match &relay_result {
|
||||||
|
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||||
|
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
relay_result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||||
|
let datacenters = if config.general.prefer_ipv6 {
|
||||||
|
&*TG_DATACENTERS_V6
|
||||||
|
} else {
|
||||||
|
&*TG_DATACENTERS_V4
|
||||||
|
};
|
||||||
|
|
||||||
|
let num_dcs = datacenters.len();
|
||||||
|
|
||||||
|
let dc_key = dc_idx.to_string();
|
||||||
|
if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
|
||||||
|
match addr_str.parse::<SocketAddr>() {
|
||||||
|
Ok(addr) => {
|
||||||
|
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
|
||||||
|
return Ok(addr);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!(dc_idx = dc_idx, addr_str = %addr_str,
|
||||||
|
"Invalid DC override address in config, ignoring");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let abs_dc = dc_idx.unsigned_abs() as usize;
|
||||||
|
if abs_dc >= 1 && abs_dc <= num_dcs {
|
||||||
|
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
|
||||||
|
}
|
||||||
|
|
||||||
|
let default_dc = config.default_dc.unwrap_or(2) as usize;
|
||||||
|
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
||||||
|
default_dc - 1
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
original_dc = dc_idx,
|
||||||
|
fallback_dc = (fallback_idx + 1) as u16,
|
||||||
|
fallback_addr = %datacenters[fallback_idx],
|
||||||
|
"Special DC ---> default_cluster"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(SocketAddr::new(
|
||||||
|
datacenters[fallback_idx],
|
||||||
|
TG_DATACENTER_PORT,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_tg_handshake_static(
|
||||||
|
mut stream: TcpStream,
|
||||||
|
success: &HandshakeSuccess,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
rng: &SecureRandom,
|
||||||
|
) -> Result<(
|
||||||
|
CryptoReader<tokio::net::tcp::OwnedReadHalf>,
|
||||||
|
CryptoWriter<tokio::net::tcp::OwnedWriteHalf>,
|
||||||
|
)> {
|
||||||
|
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
|
||||||
|
success.proto_tag,
|
||||||
|
success.dc_idx,
|
||||||
|
&success.dec_key,
|
||||||
|
success.dec_iv,
|
||||||
|
rng,
|
||||||
|
config.general.fast_mode,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
peer = %success.peer,
|
||||||
|
nonce_head = %hex::encode(&nonce[..16]),
|
||||||
|
"Sending nonce to Telegram"
|
||||||
|
);
|
||||||
|
|
||||||
|
stream.write_all(&encrypted_nonce).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let (read_half, write_half) = stream.into_split();
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
CryptoReader::new(read_half, tg_decryptor),
|
||||||
|
CryptoWriter::new(write_half, tg_encryptor),
|
||||||
|
))
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
//! MTProto Handshake Magics
|
//! MTProto Handshake
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tracing::{debug, warn, trace, info};
|
use tracing::{debug, warn, trace, info};
|
||||||
|
use zeroize::Zeroize;
|
||||||
|
|
||||||
use crate::crypto::{sha256, AesCtr};
|
use crate::crypto::{sha256, AesCtr, SecureRandom};
|
||||||
use crate::crypto::random::SECURE_RANDOM;
|
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
||||||
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
|
|||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
|
||||||
/// Result of successful handshake
|
/// Result of successful handshake
|
||||||
|
///
|
||||||
|
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
||||||
|
/// zeroized on drop.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct HandshakeSuccess {
|
pub struct HandshakeSuccess {
|
||||||
/// Authenticated user name
|
/// Authenticated user name
|
||||||
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
|
|||||||
pub is_tls: bool,
|
pub is_tls: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for HandshakeSuccess {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dec_key.zeroize();
|
||||||
|
self.dec_iv.zeroize();
|
||||||
|
self.enc_key.zeroize();
|
||||||
|
self.enc_iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle fake TLS handshake
|
/// Handle fake TLS handshake
|
||||||
pub async fn handle_tls_handshake<R, W>(
|
pub async fn handle_tls_handshake<R, W>(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
@@ -42,6 +54,7 @@ pub async fn handle_tls_handshake<R, W>(
|
|||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
replay_checker: &ReplayChecker,
|
replay_checker: &ReplayChecker,
|
||||||
|
rng: &SecureRandom,
|
||||||
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
|
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin,
|
R: AsyncRead + Unpin,
|
||||||
@@ -49,30 +62,25 @@ where
|
|||||||
{
|
{
|
||||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||||
|
|
||||||
// Check minimum length
|
|
||||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||||
debug!(peer = %peer, "TLS handshake too short");
|
debug!(peer = %peer, "TLS handshake too short");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract digest for replay check
|
|
||||||
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
||||||
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||||
|
|
||||||
// Check for replay
|
|
||||||
if replay_checker.check_tls_digest(digest_half) {
|
if replay_checker.check_tls_digest(digest_half) {
|
||||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build secrets list
|
|
||||||
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
|
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
|
||||||
.filter_map(|(name, hex)| {
|
.filter_map(|(name, hex)| {
|
||||||
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
|
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Validate handshake
|
|
||||||
let validation = match tls::validate_tls_handshake(
|
let validation = match tls::validate_tls_handshake(
|
||||||
handshake,
|
handshake,
|
||||||
&secrets,
|
&secrets,
|
||||||
@@ -89,18 +97,17 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get secret for response
|
|
||||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||||
Some((_, s)) => s,
|
Some((_, s)) => s,
|
||||||
None => return HandshakeResult::BadClient { reader, writer },
|
None => return HandshakeResult::BadClient { reader, writer },
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build and send response
|
|
||||||
let response = tls::build_server_hello(
|
let response = tls::build_server_hello(
|
||||||
secret,
|
secret,
|
||||||
&validation.digest,
|
&validation.digest,
|
||||||
&validation.session_id,
|
&validation.session_id,
|
||||||
config.censorship.fake_cert_len,
|
config.censorship.fake_cert_len,
|
||||||
|
rng,
|
||||||
);
|
);
|
||||||
|
|
||||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||||
@@ -115,7 +122,6 @@ where
|
|||||||
return HandshakeResult::Error(ProxyError::Io(e));
|
return HandshakeResult::Error(ProxyError::Io(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record for replay protection only after successful handshake
|
|
||||||
replay_checker.add_tls_digest(digest_half);
|
replay_checker.add_tls_digest(digest_half);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
@@ -147,26 +153,21 @@ where
|
|||||||
{
|
{
|
||||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||||
|
|
||||||
// Extract prekey and IV
|
|
||||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||||
|
|
||||||
// Check for replay
|
|
||||||
if replay_checker.check_handshake(dec_prekey_iv) {
|
if replay_checker.check_handshake(dec_prekey_iv) {
|
||||||
warn!(peer = %peer, "MTProto replay attack detected");
|
warn!(peer = %peer, "MTProto replay attack detected");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reversed for encryption direction
|
|
||||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
// Try each user's secret
|
|
||||||
for (user, secret_hex) in &config.access.users {
|
for (user, secret_hex) in &config.access.users {
|
||||||
let secret = match hex::decode(secret_hex) {
|
let secret = match hex::decode(secret_hex) {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(_) => continue,
|
Err(_) => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Derive decryption key
|
|
||||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
@@ -177,11 +178,9 @@ where
|
|||||||
|
|
||||||
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Decrypt handshake to check protocol tag
|
|
||||||
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||||
let decrypted = decryptor.decrypt(handshake);
|
let decrypted = decryptor.decrypt(handshake);
|
||||||
|
|
||||||
// Check protocol tag
|
|
||||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||||
.try_into()
|
.try_into()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -191,7 +190,6 @@ where
|
|||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if mode is enabled
|
|
||||||
let mode_ok = match proto_tag {
|
let mode_ok = match proto_tag {
|
||||||
ProtoTag::Secure => {
|
ProtoTag::Secure => {
|
||||||
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
|
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
|
||||||
@@ -204,12 +202,10 @@ where
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract DC index
|
|
||||||
let dc_idx = i16::from_le_bytes(
|
let dc_idx = i16::from_le_bytes(
|
||||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Derive encryption key
|
|
||||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
@@ -220,11 +216,8 @@ where
|
|||||||
|
|
||||||
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Record for replay protection
|
|
||||||
replay_checker.add_handshake(dec_prekey_iv);
|
replay_checker.add_handshake(dec_prekey_iv);
|
||||||
|
|
||||||
// Create new cipher instances
|
|
||||||
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
|
||||||
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
|
||||||
let success = HandshakeSuccess {
|
let success = HandshakeSuccess {
|
||||||
@@ -262,12 +255,14 @@ where
|
|||||||
/// Generate nonce for Telegram connection
|
/// Generate nonce for Telegram connection
|
||||||
pub fn generate_tg_nonce(
|
pub fn generate_tg_nonce(
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
|
dc_idx: i16,
|
||||||
client_dec_key: &[u8; 32],
|
client_dec_key: &[u8; 32],
|
||||||
client_dec_iv: u128,
|
client_dec_iv: u128,
|
||||||
|
rng: &SecureRandom,
|
||||||
fast_mode: bool,
|
fast_mode: bool,
|
||||||
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
||||||
loop {
|
loop {
|
||||||
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
|
let bytes = rng.bytes(HANDSHAKE_LEN);
|
||||||
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
||||||
|
|
||||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
||||||
@@ -279,6 +274,8 @@ pub fn generate_tg_nonce(
|
|||||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
|
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
|
||||||
|
|
||||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||||
|
// CRITICAL: write dc_idx so upstream DC knows where to route
|
||||||
|
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
|
||||||
|
|
||||||
if fast_mode {
|
if fast_mode {
|
||||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
|
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
|
||||||
@@ -299,19 +296,32 @@ pub fn generate_tg_nonce(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt nonce for sending to Telegram
|
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
|
||||||
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
|
||||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||||
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
|
||||||
|
|
||||||
let mut encryptor = AesCtr::new(&key, iv);
|
let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
||||||
let encrypted_full = encryptor.encrypt(nonce);
|
let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
||||||
|
|
||||||
|
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
||||||
|
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
||||||
|
|
||||||
|
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
|
||||||
|
|
||||||
let mut result = nonce[..PROTO_TAG_POS].to_vec();
|
let mut result = nonce[..PROTO_TAG_POS].to_vec();
|
||||||
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
||||||
|
|
||||||
result
|
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||||
|
|
||||||
|
(result, encryptor, decryptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
|
||||||
|
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||||
|
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
|
||||||
|
encrypted
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -323,13 +333,12 @@ mod tests {
|
|||||||
let client_dec_key = [0x42u8; 32];
|
let client_dec_key = [0x42u8; 32];
|
||||||
let client_dec_iv = 12345u128;
|
let client_dec_iv = 12345u128;
|
||||||
|
|
||||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
|
let rng = SecureRandom::new();
|
||||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
|
||||||
|
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
|
||||||
|
|
||||||
// Check length
|
|
||||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
||||||
|
|
||||||
// Check proto tag is set
|
|
||||||
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
||||||
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
||||||
}
|
}
|
||||||
@@ -339,17 +348,35 @@ mod tests {
|
|||||||
let client_dec_key = [0x42u8; 32];
|
let client_dec_key = [0x42u8; 32];
|
||||||
let client_dec_iv = 12345u128;
|
let client_dec_iv = 12345u128;
|
||||||
|
|
||||||
|
let rng = SecureRandom::new();
|
||||||
let (nonce, _, _, _, _) =
|
let (nonce, _, _, _, _) =
|
||||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
|
||||||
|
|
||||||
let encrypted = encrypt_tg_nonce(&nonce);
|
let encrypted = encrypt_tg_nonce(&nonce);
|
||||||
|
|
||||||
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
||||||
|
|
||||||
// First PROTO_TAG_POS bytes should be unchanged
|
|
||||||
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
||||||
|
|
||||||
// Rest should be different (encrypted)
|
|
||||||
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_handshake_success_zeroize_on_drop() {
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "test".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Secure,
|
||||||
|
dec_key: [0xAA; 32],
|
||||||
|
dec_iv: 0xBBBBBBBB,
|
||||||
|
enc_key: [0xCC; 32],
|
||||||
|
enc_iv: 0xDDDDDDDD,
|
||||||
|
peer: "127.0.0.1:1234".parse().unwrap(),
|
||||||
|
is_tls: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(success.dec_key, [0xAA; 32]);
|
||||||
|
assert_eq!(success.enc_key, [0xCC; 32]);
|
||||||
|
|
||||||
|
drop(success);
|
||||||
|
// Drop impl zeroizes key material without panic
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -3,12 +3,17 @@
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::str;
|
use std::str;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
#[cfg(unix)]
|
||||||
|
use tokio::net::UnixStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
|
||||||
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
/// Maximum duration for the entire masking relay.
|
||||||
|
/// Limits resource consumption from slow-loris attacks and port scanners.
|
||||||
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||||
const MASK_BUFFER_SIZE: usize = 8192;
|
const MASK_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
/// Detect client type based on initial data
|
/// Detect client type based on initial data
|
||||||
@@ -42,8 +47,8 @@ fn detect_client_type(data: &[u8]) -> &'static str {
|
|||||||
|
|
||||||
/// Handle a bad client by forwarding to mask host
|
/// Handle a bad client by forwarding to mask host
|
||||||
pub async fn handle_bad_client<R, W>(
|
pub async fn handle_bad_client<R, W>(
|
||||||
mut reader: R,
|
reader: R,
|
||||||
mut writer: W,
|
writer: W,
|
||||||
initial_data: &[u8],
|
initial_data: &[u8],
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
)
|
)
|
||||||
@@ -59,6 +64,34 @@ where
|
|||||||
|
|
||||||
let client_type = detect_client_type(initial_data);
|
let client_type = detect_client_type(initial_data);
|
||||||
|
|
||||||
|
// Connect via Unix socket or TCP
|
||||||
|
#[cfg(unix)]
|
||||||
|
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
||||||
|
debug!(
|
||||||
|
client_type = client_type,
|
||||||
|
sock = %sock_path,
|
||||||
|
data_len = initial_data.len(),
|
||||||
|
"Forwarding bad client to mask unix socket"
|
||||||
|
);
|
||||||
|
|
||||||
|
let connect_result = timeout(MASK_TIMEOUT, UnixStream::connect(sock_path)).await;
|
||||||
|
match connect_result {
|
||||||
|
Ok(Ok(stream)) => {
|
||||||
|
let (mask_read, mask_write) = stream.into_split();
|
||||||
|
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||||
|
consume_client_data(reader).await;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
debug!("Timeout connecting to mask unix socket");
|
||||||
|
consume_client_data(reader).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let mask_host = config.censorship.mask_host.as_deref()
|
let mask_host = config.censorship.mask_host.as_deref()
|
||||||
.unwrap_or(&config.censorship.tls_domain);
|
.unwrap_or(&config.censorship.tls_domain);
|
||||||
let mask_port = config.censorship.mask_port;
|
let mask_port = config.censorship.mask_port;
|
||||||
@@ -73,27 +106,37 @@ where
|
|||||||
|
|
||||||
// Connect to mask host
|
// Connect to mask host
|
||||||
let mask_addr = format!("{}:{}", mask_host, mask_port);
|
let mask_addr = format!("{}:{}", mask_host, mask_port);
|
||||||
let connect_result = timeout(
|
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||||
MASK_TIMEOUT,
|
match connect_result {
|
||||||
TcpStream::connect(&mask_addr)
|
Ok(Ok(stream)) => {
|
||||||
).await;
|
let (mask_read, mask_write) = stream.into_split();
|
||||||
|
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
|
||||||
let mask_stream = match connect_result {
|
}
|
||||||
Ok(Ok(s)) => s,
|
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(error = %e, "Failed to connect to mask host");
|
debug!(error = %e, "Failed to connect to mask host");
|
||||||
consume_client_data(reader).await;
|
consume_client_data(reader).await;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask host");
|
debug!("Timeout connecting to mask host");
|
||||||
consume_client_data(reader).await;
|
consume_client_data(reader).await;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
}
|
||||||
let (mut mask_read, mut mask_write) = mask_stream.into_split();
|
|
||||||
|
|
||||||
|
/// Relay traffic between client and mask backend
|
||||||
|
async fn relay_to_mask<R, W, MR, MW>(
|
||||||
|
mut reader: R,
|
||||||
|
mut writer: W,
|
||||||
|
mut mask_read: MR,
|
||||||
|
mut mask_write: MW,
|
||||||
|
initial_data: &[u8],
|
||||||
|
)
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
MR: AsyncRead + Unpin + Send + 'static,
|
||||||
|
MW: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
// Send initial data to mask host
|
// Send initial data to mask host
|
||||||
if mask_write.write_all(initial_data).await.is_err() {
|
if mask_write.write_all(initial_data).await.is_err() {
|
||||||
return;
|
return;
|
||||||
|
|||||||
254
src/proxy/middle_relay.rs
Normal file
254
src/proxy/middle_relay.rs
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
|
use crate::config::ProxyConfig;
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
use crate::proxy::handshake::HandshakeSuccess;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
|
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||||
|
|
||||||
|
pub(crate) async fn handle_via_middle_proxy<R, W>(
|
||||||
|
mut crypto_reader: CryptoReader<R>,
|
||||||
|
mut crypto_writer: CryptoWriter<W>,
|
||||||
|
success: HandshakeSuccess,
|
||||||
|
me_pool: Arc<MePool>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
_config: Arc<ProxyConfig>,
|
||||||
|
_buffer_pool: Arc<BufferPool>,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let user = success.user.clone();
|
||||||
|
let peer = success.peer;
|
||||||
|
let proto_tag = success.proto_tag;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
user = %user,
|
||||||
|
peer = %peer,
|
||||||
|
dc = success.dc_idx,
|
||||||
|
proto = ?proto_tag,
|
||||||
|
mode = "middle_proxy",
|
||||||
|
"Routing via Middle-End"
|
||||||
|
);
|
||||||
|
|
||||||
|
let (conn_id, mut me_rx) = me_pool.registry().register().await;
|
||||||
|
|
||||||
|
stats.increment_user_connects(&user);
|
||||||
|
stats.increment_user_curr_connects(&user);
|
||||||
|
|
||||||
|
let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag());
|
||||||
|
debug!(
|
||||||
|
user = %user,
|
||||||
|
conn_id,
|
||||||
|
proto_flags = format_args!("0x{:08x}", proto_flags),
|
||||||
|
"ME relay started"
|
||||||
|
);
|
||||||
|
|
||||||
|
let translated_local_addr = me_pool.translate_our_addr(local_addr);
|
||||||
|
|
||||||
|
let result: Result<()> = loop {
|
||||||
|
tokio::select! {
|
||||||
|
client_frame = read_client_payload(&mut crypto_reader, proto_tag) => {
|
||||||
|
match client_frame {
|
||||||
|
Ok(Some(payload)) => {
|
||||||
|
trace!(conn_id, bytes = payload.len(), "C->ME frame");
|
||||||
|
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||||
|
me_pool.send_proxy_req(
|
||||||
|
conn_id,
|
||||||
|
success.dc_idx,
|
||||||
|
peer,
|
||||||
|
translated_local_addr,
|
||||||
|
&payload,
|
||||||
|
proto_flags,
|
||||||
|
).await?;
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
debug!(conn_id, "Client EOF");
|
||||||
|
let _ = me_pool.send_close(conn_id).await;
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
|
Err(e) => break Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
me_msg = me_rx.recv() => {
|
||||||
|
match me_msg {
|
||||||
|
Some(MeResponse::Data { flags, data }) => {
|
||||||
|
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||||
|
stats.add_user_octets_to(&user, data.len() as u64);
|
||||||
|
write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?;
|
||||||
|
}
|
||||||
|
Some(MeResponse::Ack(confirm)) => {
|
||||||
|
trace!(conn_id, confirm, "ME->C quickack");
|
||||||
|
write_client_ack(&mut crypto_writer, proto_tag, confirm).await?;
|
||||||
|
}
|
||||||
|
Some(MeResponse::Close) => {
|
||||||
|
debug!(conn_id, "ME sent close");
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
debug!(conn_id, "ME channel closed");
|
||||||
|
break Err(ProxyError::Proxy("ME connection lost".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(user = %user, conn_id, "ME relay cleanup");
|
||||||
|
me_pool.registry().unregister(conn_id).await;
|
||||||
|
stats.decrement_user_curr_connects(&user);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read_client_payload<R>(
|
||||||
|
client_reader: &mut CryptoReader<R>,
|
||||||
|
proto_tag: ProtoTag,
|
||||||
|
) -> Result<Option<Vec<u8>>>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let len = match proto_tag {
|
||||||
|
ProtoTag::Abridged => {
|
||||||
|
let mut first = [0u8; 1];
|
||||||
|
match client_reader.read_exact(&mut first).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||||
|
Err(e) => return Err(ProxyError::Io(e)),
|
||||||
|
}
|
||||||
|
|
||||||
|
let len_words = if (first[0] & 0x7f) == 0x7f {
|
||||||
|
let mut ext = [0u8; 3];
|
||||||
|
client_reader
|
||||||
|
.read_exact(&mut ext)
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
||||||
|
} else {
|
||||||
|
(first[0] & 0x7f) as usize
|
||||||
|
};
|
||||||
|
|
||||||
|
len_words
|
||||||
|
.checked_mul(4)
|
||||||
|
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?
|
||||||
|
}
|
||||||
|
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||||
|
let mut len_buf = [0u8; 4];
|
||||||
|
match client_reader.read_exact(&mut len_buf).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||||
|
Err(e) => return Err(ProxyError::Io(e)),
|
||||||
|
}
|
||||||
|
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if len > 16 * 1024 * 1024 {
|
||||||
|
return Err(ProxyError::Proxy(format!("Frame too large: {len}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut payload = vec![0u8; len];
|
||||||
|
client_reader
|
||||||
|
.read_exact(&mut payload)
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
Ok(Some(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_client_payload<W>(
|
||||||
|
client_writer: &mut CryptoWriter<W>,
|
||||||
|
proto_tag: ProtoTag,
|
||||||
|
flags: u32,
|
||||||
|
data: &[u8],
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
|
||||||
|
|
||||||
|
match proto_tag {
|
||||||
|
ProtoTag::Abridged => {
|
||||||
|
if data.len() % 4 != 0 {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"Abridged payload must be 4-byte aligned, got {}",
|
||||||
|
data.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let len_words = data.len() / 4;
|
||||||
|
if len_words < 0x7f {
|
||||||
|
let mut first = len_words as u8;
|
||||||
|
if quickack {
|
||||||
|
first |= 0x80;
|
||||||
|
}
|
||||||
|
client_writer
|
||||||
|
.write_all(&[first])
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
} else if len_words < (1 << 24) {
|
||||||
|
let mut first = 0x7fu8;
|
||||||
|
if quickack {
|
||||||
|
first |= 0x80;
|
||||||
|
}
|
||||||
|
let lw = (len_words as u32).to_le_bytes();
|
||||||
|
client_writer
|
||||||
|
.write_all(&[first, lw[0], lw[1], lw[2]])
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
} else {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"Abridged frame too large: {}",
|
||||||
|
data.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
client_writer
|
||||||
|
.write_all(data)
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
}
|
||||||
|
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||||
|
let mut len = data.len() as u32;
|
||||||
|
if quickack {
|
||||||
|
len |= 0x8000_0000;
|
||||||
|
}
|
||||||
|
client_writer
|
||||||
|
.write_all(&len.to_le_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
client_writer
|
||||||
|
.write_all(data)
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client_writer.flush().await.map_err(ProxyError::Io)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_client_ack<W>(
|
||||||
|
client_writer: &mut CryptoWriter<W>,
|
||||||
|
proto_tag: ProtoTag,
|
||||||
|
confirm: u32,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let bytes = if proto_tag == ProtoTag::Abridged {
|
||||||
|
confirm.to_be_bytes()
|
||||||
|
} else {
|
||||||
|
confirm.to_le_bytes()
|
||||||
|
};
|
||||||
|
client_writer
|
||||||
|
.write_all(&bytes)
|
||||||
|
.await
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
client_writer.flush().await.map_err(ProxyError::Io)
|
||||||
|
}
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
//! Proxy Defs
|
//! Proxy Defs
|
||||||
|
|
||||||
pub mod handshake;
|
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod relay;
|
pub mod direct_relay;
|
||||||
|
pub mod handshake;
|
||||||
pub mod masking;
|
pub mod masking;
|
||||||
|
pub mod middle_relay;
|
||||||
|
pub mod relay;
|
||||||
|
|
||||||
pub use handshake::*;
|
|
||||||
pub use client::ClientHandler;
|
pub use client::ClientHandler;
|
||||||
pub use relay::*;
|
pub use handshake::*;
|
||||||
pub use masking::*;
|
pub use masking::*;
|
||||||
|
pub use relay::*;
|
||||||
|
|||||||
@@ -1,27 +1,320 @@
|
|||||||
//! Bidirectional Relay
|
//! Bidirectional Relay — poll-based, no head-of-line blocking
|
||||||
|
//!
|
||||||
|
//! ## What changed and why
|
||||||
|
//!
|
||||||
|
//! Previous implementation used a single-task `select! { biased; ... }` loop
|
||||||
|
//! where each branch called `write_all()`. This caused head-of-line blocking:
|
||||||
|
//! while `write_all()` waited for a slow writer (e.g. client on 3G downloading
|
||||||
|
//! media), the entire loop was blocked — the other direction couldn't make progress.
|
||||||
|
//!
|
||||||
|
//! Symptoms observed in production:
|
||||||
|
//! - Media loading at ~8 KB/s despite fast server connection
|
||||||
|
//! - Stop-and-go pattern with 50–500ms gaps between chunks
|
||||||
|
//! - `biased` select starving S→C direction
|
||||||
|
//! - Some users unable to load media at all
|
||||||
|
//!
|
||||||
|
//! ## New architecture
|
||||||
|
//!
|
||||||
|
//! Uses `tokio::io::copy_bidirectional` which polls both directions concurrently
|
||||||
|
//! in a single task via non-blocking `poll_read` / `poll_write` calls:
|
||||||
|
//!
|
||||||
|
//! Old (select! + write_all — BLOCKING):
|
||||||
|
//!
|
||||||
|
//! loop {
|
||||||
|
//! select! {
|
||||||
|
//! biased;
|
||||||
|
//! data = client.read() => { server.write_all(data).await; } ← BLOCKS here
|
||||||
|
//! data = server.read() => { client.write_all(data).await; } ← can't run
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! New (copy_bidirectional — CONCURRENT):
|
||||||
|
//!
|
||||||
|
//! poll(cx) {
|
||||||
|
//! // Both directions polled in the same poll cycle
|
||||||
|
//! C→S: poll_read(client) → poll_write(server) // non-blocking
|
||||||
|
//! S→C: poll_read(server) → poll_write(client) // non-blocking
|
||||||
|
//! // If one writer is Pending, the other direction still progresses
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! Benefits:
|
||||||
|
//! - No head-of-line blocking: slow client download doesn't block uploads
|
||||||
|
//! - No biased starvation: fair polling of both directions
|
||||||
|
//! - Proper flush: `copy_bidirectional` calls `poll_flush` when reader stalls,
|
||||||
|
//! so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
|
||||||
|
//! - No deadlock risk: old write_all could deadlock when both TCP buffers filled;
|
||||||
|
//! poll-based approach lets TCP flow control work correctly
|
||||||
|
//!
|
||||||
|
//! Stats tracking:
|
||||||
|
//! - `StatsIo` wraps client side, intercepts `poll_read` / `poll_write`
|
||||||
|
//! - `poll_read` on client = C→S (client sending) → `octets_from`, `msgs_from`
|
||||||
|
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
|
||||||
|
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
|
||||||
|
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{debug, trace, warn, info};
|
use tracing::{debug, trace, warn};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::BufferPool;
|
use crate::stream::BufferPool;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
|
||||||
|
|
||||||
// Activity timeout for iOS compatibility (30 minutes)
|
// ============= Constants =============
|
||||||
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
|
|
||||||
|
|
||||||
/// Relay data bidirectionally between client and server
|
/// Activity timeout for iOS compatibility.
|
||||||
|
///
|
||||||
|
/// iOS keeps Telegram connections alive in background for up to 30 minutes.
|
||||||
|
/// Closing earlier causes unnecessary reconnects and handshake overhead.
|
||||||
|
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
|
||||||
|
|
||||||
|
/// Watchdog check interval — also used for periodic rate logging.
|
||||||
|
///
|
||||||
|
/// 10 seconds gives responsive timeout detection (±10s accuracy)
|
||||||
|
/// without measurable overhead from atomic reads.
|
||||||
|
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
|
// ============= CombinedStream =============
|
||||||
|
|
||||||
|
/// Combines separate read and write halves into a single bidirectional stream.
|
||||||
|
///
|
||||||
|
/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
|
||||||
|
/// but the handshake layer produces split reader/writer pairs
|
||||||
|
/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
|
||||||
|
///
|
||||||
|
/// This wrapper reunifies them with zero overhead — each trait method
|
||||||
|
/// delegates directly to the corresponding half. No buffering, no copies.
|
||||||
|
///
|
||||||
|
/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
|
||||||
|
/// so there's no aliasing even though both are called on the same `&mut self`.
|
||||||
|
struct CombinedStream<R, W> {
|
||||||
|
reader: R,
|
||||||
|
writer: W,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, W> CombinedStream<R, W> {
|
||||||
|
fn new(reader: R, writer: W) -> Self {
|
||||||
|
Self { reader, writer }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
|
||||||
|
#[inline]
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
|
||||||
|
#[inline]
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= SharedCounters =============
|
||||||
|
|
||||||
|
/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
|
||||||
|
///
|
||||||
|
/// Using `Relaxed` ordering is sufficient because:
|
||||||
|
/// - Counters are monotonically increasing (no ABA problem)
|
||||||
|
/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
|
||||||
|
/// - No ordering dependencies between different counters
|
||||||
|
struct SharedCounters {
|
||||||
|
/// Bytes read from client (C→S direction)
|
||||||
|
c2s_bytes: AtomicU64,
|
||||||
|
/// Bytes written to client (S→C direction)
|
||||||
|
s2c_bytes: AtomicU64,
|
||||||
|
/// Number of poll_read completions (≈ C→S chunks)
|
||||||
|
c2s_ops: AtomicU64,
|
||||||
|
/// Number of poll_write completions (≈ S→C chunks)
|
||||||
|
s2c_ops: AtomicU64,
|
||||||
|
/// Milliseconds since relay epoch of last I/O activity
|
||||||
|
last_activity_ms: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SharedCounters {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
c2s_bytes: AtomicU64::new(0),
|
||||||
|
s2c_bytes: AtomicU64::new(0),
|
||||||
|
c2s_ops: AtomicU64::new(0),
|
||||||
|
s2c_ops: AtomicU64::new(0),
|
||||||
|
last_activity_ms: AtomicU64::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record activity at this instant.
|
||||||
|
#[inline]
|
||||||
|
fn touch(&self, now: Instant, epoch: Instant) {
|
||||||
|
let ms = now.duration_since(epoch).as_millis() as u64;
|
||||||
|
self.last_activity_ms.store(ms, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// How long since last recorded activity.
|
||||||
|
fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
|
||||||
|
let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
|
||||||
|
let now_ms = now.duration_since(epoch).as_millis() as u64;
|
||||||
|
Duration::from_millis(now_ms.saturating_sub(last_ms))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= StatsIo =============
|
||||||
|
|
||||||
|
/// Transparent I/O wrapper that tracks per-user statistics and activity.
|
||||||
|
///
|
||||||
|
/// Wraps the **client** side of the relay. Direction mapping:
|
||||||
|
///
|
||||||
|
/// | poll method | direction | stats updated |
|
||||||
|
/// |-------------|-----------|--------------------------------------|
|
||||||
|
/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters |
|
||||||
|
/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters |
|
||||||
|
///
|
||||||
|
/// Both update the shared activity timestamp for the watchdog.
|
||||||
|
///
|
||||||
|
/// Note on message counts: the original code counted one `read()`/`write_all()`
|
||||||
|
/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
|
||||||
|
/// Byte counts are identical; op counts may differ slightly due to different
|
||||||
|
/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
|
||||||
|
struct StatsIo<S> {
|
||||||
|
inner: S,
|
||||||
|
counters: Arc<SharedCounters>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
user: String,
|
||||||
|
epoch: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> StatsIo<S> {
|
||||||
|
fn new(
|
||||||
|
inner: S,
|
||||||
|
counters: Arc<SharedCounters>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
user: String,
|
||||||
|
epoch: Instant,
|
||||||
|
) -> Self {
|
||||||
|
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||||
|
counters.touch(Instant::now(), epoch);
|
||||||
|
Self { inner, counters, stats, user, epoch }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
let before = buf.filled().len();
|
||||||
|
|
||||||
|
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||||
|
Poll::Ready(Ok(())) => {
|
||||||
|
let n = buf.filled().len() - before;
|
||||||
|
if n > 0 {
|
||||||
|
// C→S: client sent data
|
||||||
|
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
|
||||||
|
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
|
||||||
|
this.counters.touch(Instant::now(), this.epoch);
|
||||||
|
|
||||||
|
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||||
|
this.stats.increment_user_msgs_from(&this.user);
|
||||||
|
|
||||||
|
trace!(user = %this.user, bytes = n, "C->S");
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||||
|
Poll::Ready(Ok(n)) => {
|
||||||
|
if n > 0 {
|
||||||
|
// S→C: data written to client
|
||||||
|
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
|
||||||
|
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
|
||||||
|
this.counters.touch(Instant::now(), this.epoch);
|
||||||
|
|
||||||
|
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||||
|
this.stats.increment_user_msgs_to(&this.user);
|
||||||
|
|
||||||
|
trace!(user = %this.user, bytes = n, "S->C");
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Relay =============
|
||||||
|
|
||||||
|
/// Relay data bidirectionally between client and server.
|
||||||
|
///
|
||||||
|
/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
|
||||||
|
///
|
||||||
|
/// ## API compatibility
|
||||||
|
///
|
||||||
|
/// Signature is identical to the previous implementation. The `_buffer_pool`
|
||||||
|
/// parameter is retained for call-site compatibility — `copy_bidirectional`
|
||||||
|
/// manages its own internal buffers (8 KB per direction).
|
||||||
|
///
|
||||||
|
/// ## Guarantees preserved
|
||||||
|
///
|
||||||
|
/// - Activity timeout: 30 minutes of inactivity → clean shutdown
|
||||||
|
/// - Per-user stats: bytes and ops counted per direction
|
||||||
|
/// - Periodic rate logging: every 10 seconds when active
|
||||||
|
/// - Clean shutdown: both write sides are shut down on exit
|
||||||
|
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
|
||||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||||
mut client_reader: CR,
|
client_reader: CR,
|
||||||
mut client_writer: CW,
|
client_writer: CW,
|
||||||
mut server_reader: SR,
|
server_reader: SR,
|
||||||
mut server_writer: SW,
|
server_writer: SW,
|
||||||
user: &str,
|
user: &str,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
buffer_pool: Arc<BufferPool>,
|
_buffer_pool: Arc<BufferPool>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
CR: AsyncRead + Unpin + Send + 'static,
|
CR: AsyncRead + Unpin + Send + 'static,
|
||||||
@@ -29,234 +322,145 @@ where
|
|||||||
SR: AsyncRead + Unpin + Send + 'static,
|
SR: AsyncRead + Unpin + Send + 'static,
|
||||||
SW: AsyncWrite + Unpin + Send + 'static,
|
SW: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
let user_c2s = user.to_string();
|
let epoch = Instant::now();
|
||||||
let user_s2c = user.to_string();
|
let counters = Arc::new(SharedCounters::new());
|
||||||
|
let user_owned = user.to_string();
|
||||||
|
|
||||||
let stats_c2s = Arc::clone(&stats);
|
// ── Combine split halves into bidirectional streams ──────────────
|
||||||
let stats_s2c = Arc::clone(&stats);
|
let client_combined = CombinedStream::new(client_reader, client_writer);
|
||||||
|
let mut server = CombinedStream::new(server_reader, server_writer);
|
||||||
|
|
||||||
let c2s_bytes = Arc::new(AtomicU64::new(0));
|
// Wrap client with stats/activity tracking
|
||||||
let s2c_bytes = Arc::new(AtomicU64::new(0));
|
let mut client = StatsIo::new(
|
||||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
client_combined,
|
||||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
Arc::clone(&counters),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user_owned.clone(),
|
||||||
|
epoch,
|
||||||
|
);
|
||||||
|
|
||||||
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
|
// ── Watchdog: activity timeout + periodic rate logging ──────────
|
||||||
|
let wd_counters = Arc::clone(&counters);
|
||||||
|
let wd_user = user_owned.clone();
|
||||||
|
|
||||||
let pool_c2s = buffer_pool.clone();
|
let watchdog = async {
|
||||||
let pool_s2c = buffer_pool.clone();
|
let mut prev_c2s: u64 = 0;
|
||||||
|
let mut prev_s2c: u64 = 0;
|
||||||
// Client -> Server task
|
|
||||||
let c2s = tokio::spawn(async move {
|
|
||||||
// Get buffer from pool
|
|
||||||
let mut pooled_buf = pool_c2s.get();
|
|
||||||
// CRITICAL FIX: BytesMut from pool has len 0. We must resize it to be usable as &mut [u8].
|
|
||||||
// We use the full capacity.
|
|
||||||
let cap = pooled_buf.capacity();
|
|
||||||
pooled_buf.resize(cap, 0);
|
|
||||||
|
|
||||||
let mut total_bytes = 0u64;
|
|
||||||
let mut prev_total_bytes = 0u64;
|
|
||||||
let mut msg_count = 0u64;
|
|
||||||
let mut last_activity = Instant::now();
|
|
||||||
let mut last_log = Instant::now();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Read with timeout
|
tokio::time::sleep(WATCHDOG_INTERVAL).await;
|
||||||
let read_result = tokio::time::timeout(
|
|
||||||
activity_timeout,
|
|
||||||
client_reader.read(&mut pooled_buf)
|
|
||||||
).await;
|
|
||||||
|
|
||||||
match read_result {
|
let now = Instant::now();
|
||||||
Err(_) => {
|
let idle = wd_counters.idle_duration(now, epoch);
|
||||||
|
|
||||||
|
// ── Activity timeout ────────────────────────────────────
|
||||||
|
if idle >= ACTIVITY_TIMEOUT {
|
||||||
|
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
warn!(
|
warn!(
|
||||||
user = %user_c2s,
|
user = %wd_user,
|
||||||
total_bytes = total_bytes,
|
c2s_bytes = c2s,
|
||||||
msgs = msg_count,
|
s2c_bytes = s2c,
|
||||||
idle_secs = last_activity.elapsed().as_secs(),
|
idle_secs = idle.as_secs(),
|
||||||
"Activity timeout (C->S) - no data received"
|
"Activity timeout"
|
||||||
);
|
);
|
||||||
let _ = server_writer.shutdown().await;
|
return; // Causes select! to cancel copy_bidirectional
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Ok(0)) => {
|
// ── Periodic rate logging ───────────────────────────────
|
||||||
|
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
let c2s_delta = c2s - prev_c2s;
|
||||||
|
let s2c_delta = s2c - prev_s2c;
|
||||||
|
|
||||||
|
if c2s_delta > 0 || s2c_delta > 0 {
|
||||||
|
let secs = WATCHDOG_INTERVAL.as_secs_f64();
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_c2s,
|
user = %wd_user,
|
||||||
total_bytes = total_bytes,
|
c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
|
||||||
msgs = msg_count,
|
s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
|
||||||
"Client closed connection (C->S)"
|
c2s_total = c2s,
|
||||||
|
s2c_total = s2c,
|
||||||
|
"Relay active"
|
||||||
);
|
);
|
||||||
let _ = server_writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Ok(n)) => {
|
prev_c2s = c2s;
|
||||||
total_bytes += n as u64;
|
prev_s2c = s2c;
|
||||||
msg_count += 1;
|
}
|
||||||
last_activity = Instant::now();
|
};
|
||||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
|
||||||
|
|
||||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
// ── Run bidirectional copy + watchdog concurrently ───────────────
|
||||||
stats_c2s.increment_user_msgs_from(&user_c2s);
|
//
|
||||||
|
// copy_bidirectional polls both directions in the same poll() call:
|
||||||
|
// C→S: poll_read(client/StatsIo) → poll_write(server)
|
||||||
|
// S→C: poll_read(server) → poll_write(client/StatsIo)
|
||||||
|
//
|
||||||
|
// When one direction's writer returns Pending, the other direction
|
||||||
|
// continues — no head-of-line blocking.
|
||||||
|
//
|
||||||
|
// When the watchdog fires, select! drops the copy future,
|
||||||
|
// releasing the &mut borrows on client and server.
|
||||||
|
let copy_result = tokio::select! {
|
||||||
|
result = copy_bidirectional(&mut client, &mut server) => Some(result),
|
||||||
|
_ = watchdog => None, // Activity timeout — cancel relay
|
||||||
|
};
|
||||||
|
|
||||||
trace!(
|
// ── Clean shutdown ──────────────────────────────────────────────
|
||||||
user = %user_c2s,
|
// After select!, the losing future is dropped, borrows released.
|
||||||
bytes = n,
|
// Shut down both write sides for clean TCP FIN.
|
||||||
total = total_bytes,
|
let _ = client.shutdown().await;
|
||||||
"C->S data"
|
let _ = server.shutdown().await;
|
||||||
);
|
|
||||||
|
|
||||||
// Log activity every 10 seconds with correct rate
|
// ── Final logging ───────────────────────────────────────────────
|
||||||
let elapsed = last_log.elapsed();
|
let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
|
||||||
if elapsed > Duration::from_secs(10) {
|
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
|
||||||
let delta = total_bytes - prev_total_bytes;
|
let duration = epoch.elapsed();
|
||||||
let rate = delta as f64 / elapsed.as_secs_f64();
|
|
||||||
|
|
||||||
|
match copy_result {
|
||||||
|
Some(Ok((c2s, s2c))) => {
|
||||||
|
// Normal completion — one side closed the connection
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_c2s,
|
user = %user_owned,
|
||||||
total_bytes = total_bytes,
|
c2s_bytes = c2s,
|
||||||
msgs = msg_count,
|
s2c_bytes = s2c,
|
||||||
rate_kbps = (rate / 1024.0) as u64,
|
c2s_msgs = c2s_ops,
|
||||||
"C->S transfer in progress"
|
s2c_msgs = s2c_ops,
|
||||||
);
|
duration_secs = duration.as_secs(),
|
||||||
|
|
||||||
last_log = Instant::now();
|
|
||||||
prev_total_bytes = total_bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = server_writer.write_all(&pooled_buf[..n]).await {
|
|
||||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Err(e) = server_writer.flush().await {
|
|
||||||
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Err(e)) => {
|
|
||||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Server -> Client task
|
|
||||||
let s2c = tokio::spawn(async move {
|
|
||||||
// Get buffer from pool
|
|
||||||
let mut pooled_buf = pool_s2c.get();
|
|
||||||
// CRITICAL FIX: Resize buffer
|
|
||||||
let cap = pooled_buf.capacity();
|
|
||||||
pooled_buf.resize(cap, 0);
|
|
||||||
|
|
||||||
let mut total_bytes = 0u64;
|
|
||||||
let mut prev_total_bytes = 0u64;
|
|
||||||
let mut msg_count = 0u64;
|
|
||||||
let mut last_activity = Instant::now();
|
|
||||||
let mut last_log = Instant::now();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let read_result = tokio::time::timeout(
|
|
||||||
activity_timeout,
|
|
||||||
server_reader.read(&mut pooled_buf)
|
|
||||||
).await;
|
|
||||||
|
|
||||||
match read_result {
|
|
||||||
Err(_) => {
|
|
||||||
warn!(
|
|
||||||
user = %user_s2c,
|
|
||||||
total_bytes = total_bytes,
|
|
||||||
msgs = msg_count,
|
|
||||||
idle_secs = last_activity.elapsed().as_secs(),
|
|
||||||
"Activity timeout (S->C) - no data received"
|
|
||||||
);
|
|
||||||
let _ = client_writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Ok(0)) => {
|
|
||||||
debug!(
|
|
||||||
user = %user_s2c,
|
|
||||||
total_bytes = total_bytes,
|
|
||||||
msgs = msg_count,
|
|
||||||
"Server closed connection (S->C)"
|
|
||||||
);
|
|
||||||
let _ = client_writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Ok(n)) => {
|
|
||||||
total_bytes += n as u64;
|
|
||||||
msg_count += 1;
|
|
||||||
last_activity = Instant::now();
|
|
||||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
|
||||||
|
|
||||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
|
||||||
stats_s2c.increment_user_msgs_to(&user_s2c);
|
|
||||||
|
|
||||||
trace!(
|
|
||||||
user = %user_s2c,
|
|
||||||
bytes = n,
|
|
||||||
total = total_bytes,
|
|
||||||
"S->C data"
|
|
||||||
);
|
|
||||||
|
|
||||||
let elapsed = last_log.elapsed();
|
|
||||||
if elapsed > Duration::from_secs(10) {
|
|
||||||
let delta = total_bytes - prev_total_bytes;
|
|
||||||
let rate = delta as f64 / elapsed.as_secs_f64();
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
user = %user_s2c,
|
|
||||||
total_bytes = total_bytes,
|
|
||||||
msgs = msg_count,
|
|
||||||
rate_kbps = (rate / 1024.0) as u64,
|
|
||||||
"S->C transfer in progress"
|
|
||||||
);
|
|
||||||
|
|
||||||
last_log = Instant::now();
|
|
||||||
prev_total_bytes = total_bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = client_writer.write_all(&pooled_buf[..n]).await {
|
|
||||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Err(e) = client_writer.flush().await {
|
|
||||||
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Err(e)) => {
|
|
||||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wait for either direction to complete
|
|
||||||
tokio::select! {
|
|
||||||
result = c2s => {
|
|
||||||
if let Err(e) = result {
|
|
||||||
warn!(error = %e, "C->S task panicked");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result = s2c => {
|
|
||||||
if let Err(e) = result {
|
|
||||||
warn!(error = %e, "S->C task panicked");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
|
|
||||||
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
|
|
||||||
"Relay finished"
|
"Relay finished"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
}
|
||||||
|
Some(Err(e)) => {
|
||||||
|
// I/O error in one of the directions
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
debug!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
error = %e,
|
||||||
|
"Relay error"
|
||||||
|
);
|
||||||
|
Err(e.into())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// Activity timeout (watchdog fired)
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
debug!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
"Relay finished (activity timeout)"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
353
src/stats/mod.rs
353
src/stats/mod.rs
@@ -1,31 +1,28 @@
|
|||||||
//! Statistics
|
//! Statistics and replay protection
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::{Instant, Duration};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use parking_lot::{RwLock, Mutex};
|
use parking_lot::Mutex;
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
use std::num::NonZeroUsize;
|
use std::num::NonZeroUsize;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::collections::hash_map::DefaultHasher;
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
// ============= Stats =============
|
||||||
|
|
||||||
/// Thread-safe statistics
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Stats {
|
pub struct Stats {
|
||||||
// Global counters
|
|
||||||
connects_all: AtomicU64,
|
connects_all: AtomicU64,
|
||||||
connects_bad: AtomicU64,
|
connects_bad: AtomicU64,
|
||||||
handshake_timeouts: AtomicU64,
|
handshake_timeouts: AtomicU64,
|
||||||
|
|
||||||
// Per-user stats
|
|
||||||
user_stats: DashMap<String, UserStats>,
|
user_stats: DashMap<String, UserStats>,
|
||||||
|
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||||
// Start time
|
|
||||||
start_time: RwLock<Option<Instant>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per-user statistics
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct UserStats {
|
pub struct UserStats {
|
||||||
pub connects: AtomicU64,
|
pub connects: AtomicU64,
|
||||||
@@ -43,42 +40,20 @@ impl Stats {
|
|||||||
stats
|
stats
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global stats
|
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
|
||||||
pub fn increment_connects_all(&self) {
|
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
|
||||||
self.connects_all.fetch_add(1, Ordering::Relaxed);
|
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
|
||||||
}
|
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
|
||||||
|
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
|
||||||
|
|
||||||
pub fn increment_connects_bad(&self) {
|
|
||||||
self.connects_bad.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn increment_handshake_timeouts(&self) {
|
|
||||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_connects_all(&self) -> u64 {
|
|
||||||
self.connects_all.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_connects_bad(&self) -> u64 {
|
|
||||||
self.connects_bad.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// User stats
|
|
||||||
pub fn increment_user_connects(&self, user: &str) {
|
pub fn increment_user_connects(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.connects.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.connects
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_curr_connects(&self, user: &str) {
|
pub fn increment_user_curr_connects(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.curr_connects
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||||
@@ -88,47 +63,33 @@ impl Stats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
||||||
self.user_stats
|
self.user_stats.get(user)
|
||||||
.get(user)
|
|
||||||
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.octets_from_client
|
|
||||||
.fetch_add(bytes, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.octets_to_client
|
|
||||||
.fetch_add(bytes, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.msgs_from_client
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.msgs_to_client
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||||
self.user_stats
|
self.user_stats.get(user)
|
||||||
.get(user)
|
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
s.octets_from_client.load(Ordering::Relaxed) +
|
s.octets_from_client.load(Ordering::Relaxed) +
|
||||||
s.octets_to_client.load(Ordering::Relaxed)
|
s.octets_to_client.load(Ordering::Relaxed)
|
||||||
@@ -143,57 +104,209 @@ impl Stats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sharded Replay attack checker using LRU cache
|
// ============= Replay Checker =============
|
||||||
/// Uses multiple independent LRU caches to reduce lock contention
|
|
||||||
pub struct ReplayChecker {
|
pub struct ReplayChecker {
|
||||||
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
|
shards: Vec<Mutex<ReplayShard>>,
|
||||||
shard_mask: usize,
|
shard_mask: usize,
|
||||||
|
window: Duration,
|
||||||
|
checks: AtomicU64,
|
||||||
|
hits: AtomicU64,
|
||||||
|
additions: AtomicU64,
|
||||||
|
cleanups: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ReplayEntry {
|
||||||
|
seen_at: Instant,
|
||||||
|
seq: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ReplayShard {
|
||||||
|
cache: LruCache<Box<[u8]>, ReplayEntry>,
|
||||||
|
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
|
||||||
|
seq_counter: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayShard {
|
||||||
|
fn new(cap: NonZeroUsize) -> Self {
|
||||||
|
Self {
|
||||||
|
cache: LruCache::new(cap),
|
||||||
|
queue: VecDeque::with_capacity(cap.get()),
|
||||||
|
seq_counter: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_seq(&mut self) -> u64 {
|
||||||
|
self.seq_counter += 1;
|
||||||
|
self.seq_counter
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||||
|
if window.is_zero() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let cutoff = now.checked_sub(window).unwrap_or(now);
|
||||||
|
|
||||||
|
while let Some((ts, _, _)) = self.queue.front() {
|
||||||
|
if *ts >= cutoff {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
|
||||||
|
|
||||||
|
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
|
||||||
|
// between Borrow<[u8]> and Borrow<Box<[u8]>>
|
||||||
|
if let Some(entry) = self.cache.peek(key.as_ref()) {
|
||||||
|
if entry.seq == queue_seq {
|
||||||
|
self.cache.pop(key.as_ref());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
|
||||||
|
self.cleanup(now, window);
|
||||||
|
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
|
||||||
|
self.cache.get(key).is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
|
||||||
|
self.cleanup(now, window);
|
||||||
|
|
||||||
|
let seq = self.next_seq();
|
||||||
|
let boxed_key: Box<[u8]> = key.into();
|
||||||
|
|
||||||
|
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
|
||||||
|
self.queue.push_back((now, boxed_key, seq));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn len(&self) -> usize {
|
||||||
|
self.cache.len()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReplayChecker {
|
impl ReplayChecker {
|
||||||
/// Create new replay checker with specified capacity per shard
|
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||||
/// Total capacity = capacity * num_shards
|
|
||||||
pub fn new(total_capacity: usize) -> Self {
|
|
||||||
// Use 64 shards for good concurrency
|
|
||||||
let num_shards = 64;
|
let num_shards = 64;
|
||||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||||
|
|
||||||
let mut shards = Vec::with_capacity(num_shards);
|
let mut shards = Vec::with_capacity(num_shards);
|
||||||
for _ in 0..num_shards {
|
for _ in 0..num_shards {
|
||||||
shards.push(Mutex::new(LruCache::new(cap)));
|
shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
shards,
|
shards,
|
||||||
shard_mask: num_shards - 1,
|
shard_mask: num_shards - 1,
|
||||||
|
window,
|
||||||
|
checks: AtomicU64::new(0),
|
||||||
|
hits: AtomicU64::new(0),
|
||||||
|
additions: AtomicU64::new(0),
|
||||||
|
cleanups: AtomicU64::new(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_shard(&self, key: &[u8]) -> usize {
|
fn get_shard_idx(&self, key: &[u8]) -> usize {
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
key.hash(&mut hasher);
|
key.hash(&mut hasher);
|
||||||
(hasher.finish() as usize) & self.shard_mask
|
(hasher.finish() as usize) & self.shard_mask
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
fn check(&self, data: &[u8]) -> bool {
|
||||||
let shard_idx = self.get_shard(data);
|
self.checks.fetch_add(1, Ordering::Relaxed);
|
||||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
let idx = self.get_shard_idx(data);
|
||||||
|
let mut shard = self.shards[idx].lock();
|
||||||
|
let found = shard.check(data, Instant::now(), self.window);
|
||||||
|
if found {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
found
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_handshake(&self, data: &[u8]) {
|
fn add(&self, data: &[u8]) {
|
||||||
let shard_idx = self.get_shard(data);
|
self.additions.fetch_add(1, Ordering::Relaxed);
|
||||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
let idx = self.get_shard_idx(data);
|
||||||
|
let mut shard = self.shards[idx].lock();
|
||||||
|
shard.add(data, Instant::now(), self.window);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
|
||||||
let shard_idx = self.get_shard(data);
|
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
|
||||||
self.shards[shard_idx].lock().contains(&data.to_vec())
|
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
|
||||||
|
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
|
||||||
|
|
||||||
|
pub fn stats(&self) -> ReplayStats {
|
||||||
|
let mut total_entries = 0;
|
||||||
|
let mut total_queue_len = 0;
|
||||||
|
for shard in &self.shards {
|
||||||
|
let s = shard.lock();
|
||||||
|
total_entries += s.cache.len();
|
||||||
|
total_queue_len += s.queue.len();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
ReplayStats {
|
||||||
let shard_idx = self.get_shard(data);
|
total_entries,
|
||||||
self.shards[shard_idx].lock().put(data.to_vec(), ());
|
total_queue_len,
|
||||||
|
total_checks: self.checks.load(Ordering::Relaxed),
|
||||||
|
total_hits: self.hits.load(Ordering::Relaxed),
|
||||||
|
total_additions: self.additions.load(Ordering::Relaxed),
|
||||||
|
total_cleanups: self.cleanups.load(Ordering::Relaxed),
|
||||||
|
num_shards: self.shards.len(),
|
||||||
|
window_secs: self.window.as_secs(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_periodic_cleanup(&self) {
|
||||||
|
let interval = if self.window.as_secs() > 60 {
|
||||||
|
Duration::from_secs(30)
|
||||||
|
} else {
|
||||||
|
Duration::from_secs(self.window.as_secs().max(1) / 2)
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(interval).await;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut cleaned = 0usize;
|
||||||
|
|
||||||
|
for shard_mutex in &self.shards {
|
||||||
|
let mut shard = shard_mutex.lock();
|
||||||
|
let before = shard.len();
|
||||||
|
shard.cleanup(now, self.window);
|
||||||
|
let after = shard.len();
|
||||||
|
cleaned += before.saturating_sub(after);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cleanups.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
if cleaned > 0 {
|
||||||
|
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ReplayStats {
|
||||||
|
pub total_entries: usize,
|
||||||
|
pub total_queue_len: usize,
|
||||||
|
pub total_checks: u64,
|
||||||
|
pub total_hits: u64,
|
||||||
|
pub total_additions: u64,
|
||||||
|
pub total_cleanups: u64,
|
||||||
|
pub num_shards: usize,
|
||||||
|
pub window_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayStats {
|
||||||
|
pub fn hit_rate(&self) -> f64 {
|
||||||
|
if self.total_checks == 0 { 0.0 }
|
||||||
|
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ghost_ratio(&self) -> f64 {
|
||||||
|
if self.total_entries == 0 { 0.0 }
|
||||||
|
else { self.total_queue_len as f64 / self.total_entries as f64 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,28 +317,60 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_stats_shared_counters() {
|
fn test_stats_shared_counters() {
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
stats.increment_connects_all();
|
||||||
let stats1 = Arc::clone(&stats);
|
stats.increment_connects_all();
|
||||||
let stats2 = Arc::clone(&stats);
|
stats.increment_connects_all();
|
||||||
|
|
||||||
stats1.increment_connects_all();
|
|
||||||
stats2.increment_connects_all();
|
|
||||||
stats1.increment_connects_all();
|
|
||||||
|
|
||||||
assert_eq!(stats.get_connects_all(), 3);
|
assert_eq!(stats.get_connects_all(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_replay_checker_sharding() {
|
fn test_replay_checker_basic() {
|
||||||
let checker = ReplayChecker::new(100);
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
let data1 = b"test1";
|
assert!(!checker.check_handshake(b"test1"));
|
||||||
let data2 = b"test2";
|
checker.add_handshake(b"test1");
|
||||||
|
assert!(checker.check_handshake(b"test1"));
|
||||||
|
assert!(!checker.check_handshake(b"test2"));
|
||||||
|
}
|
||||||
|
|
||||||
checker.add_handshake(data1);
|
#[test]
|
||||||
assert!(checker.check_handshake(data1));
|
fn test_replay_checker_duplicate_add() {
|
||||||
assert!(!checker.check_handshake(data2));
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
|
checker.add_handshake(b"dup");
|
||||||
|
checker.add_handshake(b"dup");
|
||||||
|
assert!(checker.check_handshake(b"dup"));
|
||||||
|
}
|
||||||
|
|
||||||
checker.add_handshake(data2);
|
#[test]
|
||||||
assert!(checker.check_handshake(data2));
|
fn test_replay_checker_expiration() {
|
||||||
|
let checker = ReplayChecker::new(100, Duration::from_millis(50));
|
||||||
|
checker.add_handshake(b"expire");
|
||||||
|
assert!(checker.check_handshake(b"expire"));
|
||||||
|
std::thread::sleep(Duration::from_millis(100));
|
||||||
|
assert!(!checker.check_handshake(b"expire"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_replay_checker_stats() {
|
||||||
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
|
checker.add_handshake(b"k1");
|
||||||
|
checker.add_handshake(b"k2");
|
||||||
|
checker.check_handshake(b"k1");
|
||||||
|
checker.check_handshake(b"k3");
|
||||||
|
let stats = checker.stats();
|
||||||
|
assert_eq!(stats.total_additions, 2);
|
||||||
|
assert_eq!(stats.total_checks, 2);
|
||||||
|
assert_eq!(stats.total_hits, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_replay_checker_many_keys() {
|
||||||
|
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
|
||||||
|
for i in 0..500u32 {
|
||||||
|
checker.add(&i.to_le_bytes());
|
||||||
|
}
|
||||||
|
for i in 0..500u32 {
|
||||||
|
assert!(checker.check(&i.to_le_bytes()));
|
||||||
|
}
|
||||||
|
assert_eq!(checker.stats().total_entries, 500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -5,8 +5,10 @@
|
|||||||
|
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use std::io::Result;
|
use std::io::Result;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::protocol::constants::ProtoTag;
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
|
||||||
// ============= Frame Types =============
|
// ============= Frame Types =============
|
||||||
|
|
||||||
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
|
|||||||
// ============= Codec Factory =============
|
// ============= Codec Factory =============
|
||||||
|
|
||||||
/// Create a frame codec for the given protocol tag
|
/// Create a frame codec for the given protocol tag
|
||||||
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
|
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
|
||||||
match proto_tag {
|
match proto_tag {
|
||||||
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
|
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
|
||||||
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
|
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
|
||||||
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
|
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,11 @@
|
|||||||
|
|
||||||
use bytes::{Bytes, BytesMut, BufMut};
|
use bytes::{Bytes, BytesMut, BufMut};
|
||||||
use std::io::{self, Error, ErrorKind};
|
use std::io::{self, Error, ErrorKind};
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio_util::codec::{Decoder, Encoder};
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|
||||||
use crate::protocol::constants::ProtoTag;
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
||||||
|
|
||||||
// ============= Unified Codec =============
|
// ============= Unified Codec =============
|
||||||
@@ -21,14 +23,17 @@ pub struct FrameCodec {
|
|||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
/// Maximum allowed frame size
|
/// Maximum allowed frame size
|
||||||
max_frame_size: usize,
|
max_frame_size: usize,
|
||||||
|
/// RNG for secure padding
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FrameCodec {
|
impl FrameCodec {
|
||||||
/// Create a new codec for the given protocol
|
/// Create a new codec for the given protocol
|
||||||
pub fn new(proto_tag: ProtoTag) -> Self {
|
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
proto_tag,
|
proto_tag,
|
||||||
max_frame_size: 16 * 1024 * 1024, // 16MB default
|
max_frame_size: 16 * 1024 * 1024, // 16MB default
|
||||||
|
rng,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
|
|||||||
match self.proto_tag {
|
match self.proto_tag {
|
||||||
ProtoTag::Abridged => encode_abridged(&frame, dst),
|
ProtoTag::Abridged => encode_abridged(&frame, dst),
|
||||||
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
|
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
|
||||||
ProtoTag::Secure => encode_secure(&frame, dst),
|
ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
|||||||
Ok(Some(Frame::with_meta(data, meta)))
|
Ok(Some(Frame::with_meta(data, meta)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
|
||||||
use crate::crypto::random::SECURE_RANDOM;
|
|
||||||
|
|
||||||
let data = &frame.data;
|
let data = &frame.data;
|
||||||
|
|
||||||
// Simple ACK: just send data
|
// Simple ACK: just send data
|
||||||
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
|||||||
// Generate padding to make length not divisible by 4
|
// Generate padding to make length not divisible by 4
|
||||||
let padding_len = if data.len() % 4 == 0 {
|
let padding_len = if data.len() % 4 == 0 {
|
||||||
// Add 1-3 bytes to make it non-aligned
|
// Add 1-3 bytes to make it non-aligned
|
||||||
(SECURE_RANDOM.range(3) + 1) as usize
|
(rng.range(3) + 1) as usize
|
||||||
} else {
|
} else {
|
||||||
// Already non-aligned, can add 0-3
|
// Already non-aligned, can add 0-3
|
||||||
SECURE_RANDOM.range(4) as usize
|
rng.range(4) as usize
|
||||||
};
|
};
|
||||||
|
|
||||||
let total_len = data.len() + padding_len;
|
let total_len = data.len() + padding_len;
|
||||||
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
|||||||
dst.extend_from_slice(data);
|
dst.extend_from_slice(data);
|
||||||
|
|
||||||
if padding_len > 0 {
|
if padding_len > 0 {
|
||||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
let padding = rng.bytes(padding_len);
|
||||||
dst.extend_from_slice(&padding);
|
dst.extend_from_slice(&padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
|
|||||||
/// Secure Intermediate protocol codec
|
/// Secure Intermediate protocol codec
|
||||||
pub struct SecureCodec {
|
pub struct SecureCodec {
|
||||||
max_frame_size: usize,
|
max_frame_size: usize,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SecureCodec {
|
impl SecureCodec {
|
||||||
pub fn new() -> Self {
|
pub fn new(rng: Arc<SecureRandom>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_frame_size: 16 * 1024 * 1024,
|
max_frame_size: 16 * 1024 * 1024,
|
||||||
|
rng,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SecureCodec {
|
impl Default for SecureCodec {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::new()
|
Self::new(Arc::new(SecureRandom::new()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
|
|||||||
type Error = io::Error;
|
type Error = io::Error;
|
||||||
|
|
||||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||||
encode_secure(&frame, dst)
|
encode_secure(&frame, dst, &self.rng)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
|
|||||||
|
|
||||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||||
let before = dst.len();
|
let before = dst.len();
|
||||||
encode_secure(frame, dst)?;
|
encode_secure(frame, dst, &self.rng)?;
|
||||||
Ok(dst.len() - before)
|
Ok(dst.len() - before)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,6 +511,8 @@ mod tests {
|
|||||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_framed_abridged() {
|
async fn test_framed_abridged() {
|
||||||
@@ -541,8 +548,8 @@ mod tests {
|
|||||||
async fn test_framed_secure() {
|
async fn test_framed_secure() {
|
||||||
let (client, server) = duplex(4096);
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
let mut writer = FramedWrite::new(client, SecureCodec::new());
|
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||||
let mut reader = FramedRead::new(server, SecureCodec::new());
|
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||||
|
|
||||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
let frame = Frame::new(original.clone());
|
let frame = Frame::new(original.clone());
|
||||||
@@ -557,8 +564,8 @@ mod tests {
|
|||||||
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||||
let (client, server) = duplex(4096);
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
|
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||||
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
|
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||||
|
|
||||||
// Use 4-byte aligned data for abridged compatibility
|
// Use 4-byte aligned data for abridged compatibility
|
||||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
@@ -607,7 +614,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_frame_too_large() {
|
fn test_frame_too_large() {
|
||||||
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
|
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
|
||||||
.with_max_frame_size(100);
|
.with_max_frame_size(100);
|
||||||
|
|
||||||
// Create a "frame" that claims to be very large
|
// Create a "frame" that claims to be very large
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
|
|||||||
use std::io::{Error, ErrorKind, Result};
|
use std::io::{Error, ErrorKind, Result};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::crypto::crc32;
|
use crate::crypto::{crc32, SecureRandom};
|
||||||
use crate::crypto::random::SECURE_RANDOM;
|
use std::sync::Arc;
|
||||||
use super::traits::{FrameMeta, LayeredStream};
|
use super::traits::{FrameMeta, LayeredStream};
|
||||||
|
|
||||||
// ============= Abridged (Compact) Frame =============
|
// ============= Abridged (Compact) Frame =============
|
||||||
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
|||||||
/// Writer for secure intermediate MTProto framing
|
/// Writer for secure intermediate MTProto framing
|
||||||
pub struct SecureIntermediateFrameWriter<W> {
|
pub struct SecureIntermediateFrameWriter<W> {
|
||||||
upstream: W,
|
upstream: W,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W> SecureIntermediateFrameWriter<W> {
|
impl<W> SecureIntermediateFrameWriter<W> {
|
||||||
pub fn new(upstream: W) -> Self {
|
pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
|
||||||
Self { upstream }
|
Self { upstream, rng }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add random padding (0-3 bytes)
|
// Add random padding (0-3 bytes)
|
||||||
let padding_len = SECURE_RANDOM.range(4);
|
let padding_len = self.rng.range(4);
|
||||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
let padding = self.rng.bytes(padding_len);
|
||||||
|
|
||||||
let total_len = data.len() + padding_len;
|
let total_len = data.len() + padding_len;
|
||||||
let len_bytes = (total_len as u32).to_le_bytes();
|
let len_bytes = (total_len as u32).to_le_bytes();
|
||||||
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||||
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
|
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||||
match proto_tag {
|
match proto_tag {
|
||||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
|
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_abridged_roundtrip() {
|
async fn test_abridged_roundtrip() {
|
||||||
@@ -539,7 +542,7 @@ mod tests {
|
|||||||
async fn test_secure_intermediate_padding() {
|
async fn test_secure_intermediate_padding() {
|
||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
|
|
||||||
let mut writer = SecureIntermediateFrameWriter::new(client);
|
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
|
||||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||||
|
|
||||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||||
@@ -572,7 +575,7 @@ mod tests {
|
|||||||
async fn test_frame_reader_kind() {
|
async fn test_frame_reader_kind() {
|
||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
|
|
||||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
|
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
|
||||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||||
|
|
||||||
let data = vec![1u8, 2, 3, 4];
|
let data = vec![1u8, 2, 3, 4];
|
||||||
|
|||||||
925
src/transport/middle_proxy.rs
Normal file
925
src/transport/middle_proxy.rs
Normal file
@@ -0,0 +1,925 @@
|
|||||||
|
//! Middle Proxy RPC Transport
|
||||||
|
//!
|
||||||
|
//! Implements Telegram Middle-End RPC protocol for routing to ALL DCs (including CDN).
|
||||||
|
//!
|
||||||
|
//! ## Phase 3 fixes:
|
||||||
|
//! - ROOT CAUSE: Use Telegram proxy-secret (binary file) not user secret
|
||||||
|
//! - Streaming handshake response (no fixed-size read deadlock)
|
||||||
|
//! - Health monitoring + reconnection
|
||||||
|
//! - Hex diagnostics for debugging
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::time::Duration;
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||||
|
use tokio::time::{timeout, Instant};
|
||||||
|
use tracing::{debug, info, trace, warn, error};
|
||||||
|
|
||||||
|
use crate::crypto::{crc32, derive_middleproxy_keys, AesCbc, SecureRandom};
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
|
||||||
|
// ========== Proxy Secret Fetching ==========
|
||||||
|
|
||||||
|
/// Fetch the Telegram proxy-secret binary file.
|
||||||
|
///
|
||||||
|
/// This is NOT the user secret (-S flag, 16 bytes hex for clients).
|
||||||
|
/// This is the infrastructure secret (--aes-pwd in C MTProxy),
|
||||||
|
/// a binary file of 32-512 bytes used for ME RPC key derivation.
|
||||||
|
///
|
||||||
|
/// Strategy: try local cache, then download from Telegram.
|
||||||
|
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
|
||||||
|
let cache = cache_path.unwrap_or("proxy-secret");
|
||||||
|
|
||||||
|
// 1. Try local cache (< 24h old)
|
||||||
|
if let Ok(metadata) = tokio::fs::metadata(cache).await {
|
||||||
|
if let Ok(modified) = metadata.modified() {
|
||||||
|
let age = std::time::SystemTime::now()
|
||||||
|
.duration_since(modified)
|
||||||
|
.unwrap_or(Duration::from_secs(u64::MAX));
|
||||||
|
if age < Duration::from_secs(86400) {
|
||||||
|
if let Ok(data) = tokio::fs::read(cache).await {
|
||||||
|
if data.len() >= 32 {
|
||||||
|
info!(
|
||||||
|
path = cache,
|
||||||
|
len = data.len(),
|
||||||
|
age_hours = age.as_secs() / 3600,
|
||||||
|
"Loaded proxy-secret from cache"
|
||||||
|
);
|
||||||
|
return Ok(data);
|
||||||
|
}
|
||||||
|
warn!(path = cache, len = data.len(), "Cached proxy-secret too short");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Download from Telegram
|
||||||
|
info!("Downloading proxy-secret from core.telegram.org...");
|
||||||
|
let data = download_proxy_secret().await?;
|
||||||
|
|
||||||
|
// 3. Cache locally (best-effort)
|
||||||
|
if let Err(e) = tokio::fs::write(cache, &data).await {
|
||||||
|
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
|
||||||
|
} else {
|
||||||
|
debug!(path = cache, len = data.len(), "Cached proxy-secret");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_proxy_secret() -> Result<Vec<u8>> {
|
||||||
|
let url = "https://core.telegram.org/getProxySecret";
|
||||||
|
let resp = reqwest::get(url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {}", e)))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"proxy-secret download HTTP {}", resp.status()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = resp.bytes().await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {}", e)))?
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
if data.len() < 32 {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"proxy-secret too short: {} bytes (need >= 32)", data.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(len = data.len(), "Downloaded proxy-secret OK");
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RPC Frame helpers ==========
|
||||||
|
|
||||||
|
/// Build an RPC frame: [len(4) | seq_no(4) | payload | crc32(4)]
|
||||||
|
fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
|
||||||
|
let total_len = (4 + 4 + payload.len() + 4) as u32;
|
||||||
|
let mut f = Vec::with_capacity(total_len as usize);
|
||||||
|
f.extend_from_slice(&total_len.to_le_bytes());
|
||||||
|
f.extend_from_slice(&seq_no.to_le_bytes());
|
||||||
|
f.extend_from_slice(payload);
|
||||||
|
let c = crc32(&f);
|
||||||
|
f.extend_from_slice(&c.to_le_bytes());
|
||||||
|
f
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read one plaintext RPC frame. Returns (seq_no, payload).
|
||||||
|
async fn read_rpc_frame_plaintext(
|
||||||
|
rd: &mut (impl AsyncReadExt + Unpin),
|
||||||
|
) -> Result<(i32, Vec<u8>)> {
|
||||||
|
let mut len_buf = [0u8; 4];
|
||||||
|
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
|
||||||
|
let total_len = u32::from_le_bytes(len_buf) as usize;
|
||||||
|
|
||||||
|
if total_len < 12 || total_len > (1 << 24) {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Bad RPC frame length: {}", total_len),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rest = vec![0u8; total_len - 4];
|
||||||
|
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
let mut full = Vec::with_capacity(total_len);
|
||||||
|
full.extend_from_slice(&len_buf);
|
||||||
|
full.extend_from_slice(&rest);
|
||||||
|
|
||||||
|
let crc_offset = total_len - 4;
|
||||||
|
let expected_crc = u32::from_le_bytes([
|
||||||
|
full[crc_offset], full[crc_offset + 1],
|
||||||
|
full[crc_offset + 2], full[crc_offset + 3],
|
||||||
|
]);
|
||||||
|
let actual_crc = crc32(&full[..crc_offset]);
|
||||||
|
if expected_crc != actual_crc {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("CRC mismatch: 0x{:08x} vs 0x{:08x}", expected_crc, actual_crc),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let seq_no = i32::from_le_bytes([full[4], full[5], full[6], full[7]]);
|
||||||
|
let payload = full[8..crc_offset].to_vec();
|
||||||
|
Ok((seq_no, payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RPC Nonce (32 bytes payload) ==========
|
||||||
|
|
||||||
|
fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
|
||||||
|
let mut p = [0u8; 32];
|
||||||
|
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
|
||||||
|
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
|
||||||
|
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
|
||||||
|
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
|
||||||
|
p[16..32].copy_from_slice(nonce);
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> {
|
||||||
|
if d.len() < 32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Nonce payload too short: {} bytes", d.len()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let t = u32::from_le_bytes([d[0], d[1], d[2], d[3]]);
|
||||||
|
if t != RPC_NONCE_U32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Expected RPC_NONCE 0x{:08x}, got 0x{:08x}", RPC_NONCE_U32, t),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let schema = u32::from_le_bytes([d[8], d[9], d[10], d[11]]);
|
||||||
|
let ts = u32::from_le_bytes([d[12], d[13], d[14], d[15]]);
|
||||||
|
let mut nonce = [0u8; 16];
|
||||||
|
nonce.copy_from_slice(&d[16..32]);
|
||||||
|
Ok((schema, ts, nonce))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RPC Handshake (32 bytes payload) ==========
|
||||||
|
|
||||||
|
fn build_handshake_payload(our_ip: u32, our_port: u16, peer_ip: u32, peer_port: u16) -> [u8; 32] {
|
||||||
|
let mut p = [0u8; 32];
|
||||||
|
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
|
||||||
|
// flags = 0 at offset 4..8
|
||||||
|
|
||||||
|
// sender_pid: {ip(4), port(2), pid(2), utime(4)} at offset 8..20
|
||||||
|
p[8..12].copy_from_slice(&our_ip.to_le_bytes());
|
||||||
|
p[12..14].copy_from_slice(&our_port.to_le_bytes());
|
||||||
|
let pid = (std::process::id() & 0xFFFF) as u16;
|
||||||
|
p[14..16].copy_from_slice(&pid.to_le_bytes());
|
||||||
|
let utime = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as u32;
|
||||||
|
p[16..20].copy_from_slice(&utime.to_le_bytes());
|
||||||
|
|
||||||
|
// peer_pid: {ip(4), port(2), pid(2), utime(4)} at offset 20..32
|
||||||
|
p[20..24].copy_from_slice(&peer_ip.to_le_bytes());
|
||||||
|
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== CBC helpers ==========
|
||||||
|
|
||||||
|
fn cbc_encrypt_padded(key: &[u8; 32], iv: &[u8; 16], plaintext: &[u8]) -> Result<(Vec<u8>, [u8; 16])> {
|
||||||
|
let pad = (16 - (plaintext.len() % 16)) % 16;
|
||||||
|
let mut buf = plaintext.to_vec();
|
||||||
|
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
|
||||||
|
for i in 0..pad {
|
||||||
|
buf.push(pad_pattern[i % 4]);
|
||||||
|
}
|
||||||
|
let cipher = AesCbc::new(*key, *iv);
|
||||||
|
cipher.encrypt_in_place(&mut buf)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {}", e)))?;
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
if buf.len() >= 16 {
|
||||||
|
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
|
||||||
|
}
|
||||||
|
Ok((buf, new_iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cbc_decrypt_inplace(key: &[u8; 32], iv: &[u8; 16], data: &mut [u8]) -> Result<[u8; 16]> {
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
if data.len() >= 16 {
|
||||||
|
new_iv.copy_from_slice(&data[data.len() - 16..]);
|
||||||
|
}
|
||||||
|
AesCbc::new(*key, *iv)
|
||||||
|
.decrypt_in_place(data)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {}", e)))?;
|
||||||
|
Ok(new_iv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== IPv4 helpers ==========
|
||||||
|
|
||||||
|
fn ipv4_to_mapped_v6(ip: Ipv4Addr) -> [u8; 16] {
|
||||||
|
let mut buf = [0u8; 16];
|
||||||
|
buf[10] = 0xFF;
|
||||||
|
buf[11] = 0xFF;
|
||||||
|
let o = ip.octets();
|
||||||
|
buf[12] = o[0]; buf[13] = o[1]; buf[14] = o[2]; buf[15] = o[3];
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn addr_to_ip_u32(addr: &SocketAddr) -> u32 {
|
||||||
|
match addr.ip() {
|
||||||
|
IpAddr::V4(v4) => u32::from_be_bytes(v4.octets()),
|
||||||
|
IpAddr::V6(v6) => {
|
||||||
|
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||||
|
u32::from_be_bytes(v4.octets())
|
||||||
|
} else { 0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== ME Response ==========
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum MeResponse {
|
||||||
|
Data(Bytes),
|
||||||
|
Ack(u32),
|
||||||
|
Close,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== Connection Registry ==========
|
||||||
|
|
||||||
|
pub struct ConnRegistry {
|
||||||
|
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
|
||||||
|
next_id: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
map: RwLock::new(HashMap::new()),
|
||||||
|
next_id: AtomicU64::new(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
|
||||||
|
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let (tx, rx) = mpsc::channel(256);
|
||||||
|
self.map.write().await.insert(id, tx);
|
||||||
|
(id, rx)
|
||||||
|
}
|
||||||
|
pub async fn unregister(&self, id: u64) {
|
||||||
|
self.map.write().await.remove(&id);
|
||||||
|
}
|
||||||
|
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
|
||||||
|
let m = self.map.read().await;
|
||||||
|
if let Some(tx) = m.get(&id) {
|
||||||
|
tx.send(resp).await.is_ok()
|
||||||
|
} else { false }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RPC Writer (streaming CBC) ==========
|
||||||
|
|
||||||
|
struct RpcWriter {
|
||||||
|
writer: tokio::io::WriteHalf<TcpStream>,
|
||||||
|
key: [u8; 32],
|
||||||
|
iv: [u8; 16],
|
||||||
|
seq_no: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RpcWriter {
|
||||||
|
async fn send(&mut self, payload: &[u8]) -> Result<()> {
|
||||||
|
let frame = build_rpc_frame(self.seq_no, payload);
|
||||||
|
self.seq_no += 1;
|
||||||
|
|
||||||
|
let pad = (16 - (frame.len() % 16)) % 16;
|
||||||
|
let mut buf = frame;
|
||||||
|
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
|
||||||
|
for i in 0..pad {
|
||||||
|
buf.push(pad_pattern[i % 4]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(self.key, self.iv);
|
||||||
|
cipher.encrypt_in_place(&mut buf)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("{}", e)))?;
|
||||||
|
|
||||||
|
if buf.len() >= 16 {
|
||||||
|
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
|
||||||
|
}
|
||||||
|
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RPC_PROXY_REQ ==========
|
||||||
|
|
||||||
|
|
||||||
|
fn build_proxy_req_payload(
|
||||||
|
conn_id: u64,
|
||||||
|
client_addr: SocketAddr,
|
||||||
|
our_addr: SocketAddr,
|
||||||
|
data: &[u8],
|
||||||
|
proxy_tag: Option<&[u8]>,
|
||||||
|
proto_flags: u32,
|
||||||
|
) -> Vec<u8> {
|
||||||
|
// flags are pre-calculated by proto_flags_for_tag
|
||||||
|
// We just need to ensure FLAG_HAS_AD_TAG is set if we have a tag (it is set by default in our new function, but let's be safe)
|
||||||
|
let mut flags = proto_flags;
|
||||||
|
|
||||||
|
// The C code logic:
|
||||||
|
// flags = (transport_flags) | 0x1000 | 0x20000 | 0x8 (if tag)
|
||||||
|
// Our proto_flags_for_tag returns: 0x8 | 0x1000 | 0x20000 | transport_flags
|
||||||
|
// So we are good.
|
||||||
|
|
||||||
|
let b_cap = 128 + data.len();
|
||||||
|
let mut b = Vec::with_capacity(b_cap);
|
||||||
|
|
||||||
|
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
|
||||||
|
b.extend_from_slice(&flags.to_le_bytes());
|
||||||
|
b.extend_from_slice(&conn_id.to_le_bytes());
|
||||||
|
|
||||||
|
// Client IP (16 bytes IPv4-mapped-v6) + port (4 bytes)
|
||||||
|
match client_addr.ip() {
|
||||||
|
IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)),
|
||||||
|
IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()),
|
||||||
|
}
|
||||||
|
b.extend_from_slice(&(client_addr.port() as u32).to_le_bytes());
|
||||||
|
|
||||||
|
// Our IP (16 bytes) + port (4 bytes)
|
||||||
|
match our_addr.ip() {
|
||||||
|
IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)),
|
||||||
|
IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()),
|
||||||
|
}
|
||||||
|
b.extend_from_slice(&(our_addr.port() as u32).to_le_bytes());
|
||||||
|
|
||||||
|
// Extra section (proxy_tag)
|
||||||
|
if flags & 12 != 0 {
|
||||||
|
let extra_start = b.len();
|
||||||
|
b.extend_from_slice(&0u32.to_le_bytes()); // placeholder
|
||||||
|
|
||||||
|
if let Some(tag) = proxy_tag {
|
||||||
|
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
|
||||||
|
// TL string encoding
|
||||||
|
if tag.len() < 254 {
|
||||||
|
b.push(tag.len() as u8);
|
||||||
|
b.extend_from_slice(tag);
|
||||||
|
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
|
||||||
|
b.extend(std::iter::repeat(0u8).take(pad));
|
||||||
|
} else {
|
||||||
|
b.push(0xfe);
|
||||||
|
let len_bytes = (tag.len() as u32).to_le_bytes();
|
||||||
|
b.extend_from_slice(&len_bytes[..3]);
|
||||||
|
b.extend_from_slice(tag);
|
||||||
|
let pad = (4 - (tag.len() % 4)) % 4;
|
||||||
|
b.extend(std::iter::repeat(0u8).take(pad));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let extra_bytes = (b.len() - extra_start - 4) as u32;
|
||||||
|
let eb = extra_bytes.to_le_bytes();
|
||||||
|
b[extra_start..extra_start + 4].copy_from_slice(&eb);
|
||||||
|
}
|
||||||
|
|
||||||
|
b.extend_from_slice(data);
|
||||||
|
b
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== ME Pool ==========
|
||||||
|
|
||||||
|
pub struct MePool {
|
||||||
|
registry: Arc<ConnRegistry>,
|
||||||
|
writers: Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>>,
|
||||||
|
rr: AtomicU64,
|
||||||
|
proxy_tag: Option<Vec<u8>>,
|
||||||
|
/// Telegram proxy-secret (binary, 32-512 bytes)
|
||||||
|
proxy_secret: Vec<u8>,
|
||||||
|
pool_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MePool {
|
||||||
|
pub fn new(proxy_tag: Option<Vec<u8>>, proxy_secret: Vec<u8>) -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
registry: Arc::new(ConnRegistry::new()),
|
||||||
|
writers: Arc::new(RwLock::new(Vec::new())),
|
||||||
|
rr: AtomicU64::new(0),
|
||||||
|
proxy_tag,
|
||||||
|
proxy_secret,
|
||||||
|
pool_size: 2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn registry(&self) -> &Arc<ConnRegistry> {
|
||||||
|
&self.registry
|
||||||
|
}
|
||||||
|
|
||||||
|
fn writers_arc(&self) -> Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>> {
|
||||||
|
self.writers.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// key_selector = first 4 bytes of proxy-secret as LE u32
|
||||||
|
/// C: main_secret.key_signature via union { char secret[]; int key_signature; }
|
||||||
|
fn key_selector(&self) -> u32 {
|
||||||
|
if self.proxy_secret.len() >= 4 {
|
||||||
|
u32::from_le_bytes([
|
||||||
|
self.proxy_secret[0], self.proxy_secret[1],
|
||||||
|
self.proxy_secret[2], self.proxy_secret[3],
|
||||||
|
])
|
||||||
|
} else { 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
pool_size: usize,
|
||||||
|
rng: &SecureRandom,
|
||||||
|
) -> Result<()> {
|
||||||
|
let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4;
|
||||||
|
let ks = self.key_selector();
|
||||||
|
info!(
|
||||||
|
me_servers = addrs.len(),
|
||||||
|
pool_size,
|
||||||
|
key_selector = format_args!("0x{:08x}", ks),
|
||||||
|
secret_len = self.proxy_secret.len(),
|
||||||
|
"Initializing ME pool"
|
||||||
|
);
|
||||||
|
|
||||||
|
for &(ip, port) in addrs.iter() {
|
||||||
|
for i in 0..pool_size {
|
||||||
|
let addr = SocketAddr::new(ip, port);
|
||||||
|
match self.connect_one(addr, rng).await {
|
||||||
|
Ok(()) => info!(%addr, idx = i, "ME connected"),
|
||||||
|
Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.writers.read().await.len() >= pool_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.writers.read().await.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("No ME connections".into()));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_one(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
addr: SocketAddr,
|
||||||
|
rng: &SecureRandom,
|
||||||
|
) -> Result<()> {
|
||||||
|
let secret = &self.proxy_secret;
|
||||||
|
if secret.len() < 32 {
|
||||||
|
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== TCP connect =====
|
||||||
|
let stream = timeout(
|
||||||
|
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
|
||||||
|
TcpStream::connect(addr),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })?
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
stream.set_nodelay(true).ok();
|
||||||
|
|
||||||
|
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
|
||||||
|
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
|
||||||
|
let (mut rd, mut wr) = tokio::io::split(stream);
|
||||||
|
|
||||||
|
// ===== 1. Send RPC nonce (plaintext, seq=-2) =====
|
||||||
|
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
|
||||||
|
let crypto_ts = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as u32;
|
||||||
|
let ks = self.key_selector();
|
||||||
|
|
||||||
|
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
|
||||||
|
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
%addr,
|
||||||
|
frame_len = nonce_frame.len(),
|
||||||
|
key_sel = format_args!("0x{:08x}", ks),
|
||||||
|
crypto_ts,
|
||||||
|
"Sending nonce"
|
||||||
|
);
|
||||||
|
|
||||||
|
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
|
||||||
|
wr.flush().await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
// ===== 2. Read server nonce (plaintext, seq=-2) =====
|
||||||
|
let (srv_seq, srv_nonce_payload) = timeout(
|
||||||
|
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
|
||||||
|
read_rpc_frame_plaintext(&mut rd),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
|
||||||
|
|
||||||
|
if srv_seq != -2 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Expected seq=-2, got {}", srv_seq),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (schema, _srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
|
||||||
|
if schema != RPC_CRYPTO_AES_U32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Unsupported crypto schema: 0x{:x}", schema),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(%addr, "Nonce exchange OK, deriving keys");
|
||||||
|
|
||||||
|
// ===== 3. Derive AES-256-CBC keys =====
|
||||||
|
// C buffer layout:
|
||||||
|
// [0..16] nonce_server (srv_nonce)
|
||||||
|
// [16..32] nonce_client (my_nonce)
|
||||||
|
// [32..36] client_timestamp
|
||||||
|
// [36..40] server_ip
|
||||||
|
// [40..42] client_port
|
||||||
|
// [42..48] "CLIENT" or "SERVER"
|
||||||
|
// [48..52] client_ip
|
||||||
|
// [52..54] server_port
|
||||||
|
// [54..54+N] secret (proxy-secret binary)
|
||||||
|
// [54+N..70+N] nonce_server
|
||||||
|
// nonce_client(16)
|
||||||
|
|
||||||
|
let ts_bytes = crypto_ts.to_le_bytes();
|
||||||
|
let server_ip = addr_to_ip_u32(&peer_addr);
|
||||||
|
let client_ip = addr_to_ip_u32(&local_addr);
|
||||||
|
let server_ip_bytes = server_ip.to_le_bytes();
|
||||||
|
let client_ip_bytes = client_ip.to_le_bytes();
|
||||||
|
let server_port_bytes = peer_addr.port().to_le_bytes();
|
||||||
|
let client_port_bytes = local_addr.port().to_le_bytes();
|
||||||
|
|
||||||
|
let (wk, wi) = derive_middleproxy_keys(
|
||||||
|
&srv_nonce, &my_nonce, &ts_bytes,
|
||||||
|
Some(&server_ip_bytes), &client_port_bytes,
|
||||||
|
b"CLIENT",
|
||||||
|
Some(&client_ip_bytes), &server_port_bytes,
|
||||||
|
secret, None, None,
|
||||||
|
);
|
||||||
|
let (rk, ri) = derive_middleproxy_keys(
|
||||||
|
&srv_nonce, &my_nonce, &ts_bytes,
|
||||||
|
Some(&server_ip_bytes), &client_port_bytes,
|
||||||
|
b"SERVER",
|
||||||
|
Some(&client_ip_bytes), &server_port_bytes,
|
||||||
|
secret, None, None,
|
||||||
|
);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
%addr,
|
||||||
|
write_key = %hex::encode(&wk[..8]),
|
||||||
|
read_key = %hex::encode(&rk[..8]),
|
||||||
|
"Keys derived"
|
||||||
|
);
|
||||||
|
|
||||||
|
// ===== 4. Send encrypted handshake (seq=-1) =====
|
||||||
|
let hs_payload = build_handshake_payload(
|
||||||
|
client_ip, local_addr.port(),
|
||||||
|
server_ip, peer_addr.port(),
|
||||||
|
);
|
||||||
|
let hs_frame = build_rpc_frame(-1, &hs_payload);
|
||||||
|
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
|
||||||
|
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
|
||||||
|
wr.flush().await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
debug!(%addr, enc_len = encrypted_hs.len(), "Sent encrypted handshake");
|
||||||
|
|
||||||
|
// ===== 5. Read encrypted handshake response (STREAMING) =====
|
||||||
|
// Server sends encrypted handshake. C crypto layer may send partial
|
||||||
|
// blocks (only complete 16-byte blocks get encrypted at a time).
|
||||||
|
// We read incrementally and decrypt block-by-block.
|
||||||
|
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
|
||||||
|
let mut enc_buf = BytesMut::with_capacity(256);
|
||||||
|
let mut dec_buf = BytesMut::with_capacity(256);
|
||||||
|
let mut read_iv = ri;
|
||||||
|
let mut handshake_ok = false;
|
||||||
|
|
||||||
|
while Instant::now() < deadline && !handshake_ok {
|
||||||
|
let remaining = deadline - Instant::now();
|
||||||
|
let mut tmp = [0u8; 256];
|
||||||
|
let n = match timeout(remaining, rd.read(&mut tmp)).await {
|
||||||
|
Ok(Ok(0)) => return Err(ProxyError::Io(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::UnexpectedEof, "ME closed during handshake",
|
||||||
|
))),
|
||||||
|
Ok(Ok(n)) => n,
|
||||||
|
Ok(Err(e)) => return Err(ProxyError::Io(e)),
|
||||||
|
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
|
||||||
|
};
|
||||||
|
enc_buf.extend_from_slice(&tmp[..n]);
|
||||||
|
|
||||||
|
// Decrypt complete 16-byte blocks
|
||||||
|
let blocks = enc_buf.len() / 16 * 16;
|
||||||
|
if blocks > 0 {
|
||||||
|
let mut chunk = vec![0u8; blocks];
|
||||||
|
chunk.copy_from_slice(&enc_buf[..blocks]);
|
||||||
|
let new_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
|
||||||
|
read_iv = new_iv;
|
||||||
|
dec_buf.extend_from_slice(&chunk);
|
||||||
|
let _ = enc_buf.split_to(blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse RPC frame from decrypted data
|
||||||
|
while dec_buf.len() >= 4 {
|
||||||
|
let fl = u32::from_le_bytes([
|
||||||
|
dec_buf[0], dec_buf[1], dec_buf[2], dec_buf[3],
|
||||||
|
]) as usize;
|
||||||
|
|
||||||
|
// Skip noop padding
|
||||||
|
if fl == 4 {
|
||||||
|
let _ = dec_buf.split_to(4);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if fl < 12 || fl > (1 << 24) {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Bad HS response frame len: {}", fl),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if dec_buf.len() < fl {
|
||||||
|
break; // need more data
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame = dec_buf.split_to(fl);
|
||||||
|
|
||||||
|
// CRC32 check
|
||||||
|
let pe = fl - 4;
|
||||||
|
let ec = u32::from_le_bytes([
|
||||||
|
frame[pe], frame[pe + 1], frame[pe + 2], frame[pe + 3],
|
||||||
|
]);
|
||||||
|
let ac = crc32(&frame[..pe]);
|
||||||
|
if ec != ac {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("HS CRC mismatch: 0x{:08x} vs 0x{:08x}", ec, ac),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check type
|
||||||
|
let hs_type = u32::from_le_bytes([
|
||||||
|
frame[8], frame[9], frame[10], frame[11],
|
||||||
|
]);
|
||||||
|
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
|
||||||
|
let err_code = if frame.len() >= 16 {
|
||||||
|
i32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]])
|
||||||
|
} else { -1 };
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("ME rejected handshake (error={})", err_code),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if hs_type != RPC_HANDSHAKE_U32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Expected HANDSHAKE 0x{:08x}, got 0x{:08x}", RPC_HANDSHAKE_U32, hs_type),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
handshake_ok = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !handshake_ok {
|
||||||
|
return Err(ProxyError::TgHandshakeTimeout);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(%addr, "RPC handshake OK");
|
||||||
|
|
||||||
|
// ===== 6. Setup writer + reader =====
|
||||||
|
let rpc_w = Arc::new(Mutex::new(RpcWriter {
|
||||||
|
writer: wr,
|
||||||
|
key: wk,
|
||||||
|
iv: write_iv,
|
||||||
|
seq_no: 0,
|
||||||
|
}));
|
||||||
|
self.writers.write().await.push(rpc_w.clone());
|
||||||
|
|
||||||
|
let reg = self.registry.clone();
|
||||||
|
let w_pong = rpc_w.clone();
|
||||||
|
let w_pool = self.writers_arc();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await {
|
||||||
|
warn!(error = %e, "ME reader ended");
|
||||||
|
}
|
||||||
|
// Remove dead writer from pool
|
||||||
|
let mut ws = w_pool.write().await;
|
||||||
|
ws.retain(|w| !Arc::ptr_eq(w, &w_pong));
|
||||||
|
info!(remaining = ws.len(), "Dead ME writer removed from pool");
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_proxy_req(
|
||||||
|
&self,
|
||||||
|
conn_id: u64,
|
||||||
|
client_addr: SocketAddr,
|
||||||
|
our_addr: SocketAddr,
|
||||||
|
data: &[u8],
|
||||||
|
proto_flags: u32,
|
||||||
|
) -> Result<()> {
|
||||||
|
let payload = build_proxy_req_payload(
|
||||||
|
conn_id, client_addr, our_addr, data,
|
||||||
|
self.proxy_tag.as_deref(), proto_flags,
|
||||||
|
);
|
||||||
|
loop {
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
if ws.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||||
|
}
|
||||||
|
let idx = self.rr.fetch_add(1, Ordering::Relaxed) as usize % ws.len();
|
||||||
|
let w = ws[idx].clone();
|
||||||
|
drop(ws);
|
||||||
|
match w.lock().await.send(&payload).await {
|
||||||
|
Ok(()) => return Ok(()),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "ME write failed, removing dead conn");
|
||||||
|
let mut ws = self.writers.write().await;
|
||||||
|
ws.retain(|o| !Arc::ptr_eq(o, &w));
|
||||||
|
if ws.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
if !ws.is_empty() {
|
||||||
|
let w = ws[0].clone();
|
||||||
|
drop(ws);
|
||||||
|
let mut p = Vec::with_capacity(12);
|
||||||
|
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
|
||||||
|
p.extend_from_slice(&conn_id.to_le_bytes());
|
||||||
|
if let Err(e) = w.lock().await.send(&p).await {
|
||||||
|
debug!(error = %e, "ME close write failed");
|
||||||
|
let mut ws = self.writers.write().await;
|
||||||
|
ws.retain(|o| !Arc::ptr_eq(o, &w));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.registry.unregister(conn_id).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_count(&self) -> usize {
|
||||||
|
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== Reader Loop ==========
|
||||||
|
|
||||||
|
async fn reader_loop(
|
||||||
|
mut rd: tokio::io::ReadHalf<TcpStream>,
|
||||||
|
dk: [u8; 32],
|
||||||
|
mut div: [u8; 16],
|
||||||
|
reg: Arc<ConnRegistry>,
|
||||||
|
mut enc_leftover: BytesMut,
|
||||||
|
mut dec: BytesMut,
|
||||||
|
writer: Arc<Mutex<RpcWriter>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut raw = enc_leftover;
|
||||||
|
loop {
|
||||||
|
let mut tmp = [0u8; 16384];
|
||||||
|
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
|
||||||
|
if n == 0 { return Ok(()); }
|
||||||
|
raw.extend_from_slice(&tmp[..n]);
|
||||||
|
|
||||||
|
// Decrypt complete 16-byte blocks
|
||||||
|
let blocks = raw.len() / 16 * 16;
|
||||||
|
if blocks > 0 {
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
|
||||||
|
let mut chunk = vec![0u8; blocks];
|
||||||
|
chunk.copy_from_slice(&raw[..blocks]);
|
||||||
|
AesCbc::new(dk, div)
|
||||||
|
.decrypt_in_place(&mut chunk)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("{}", e)))?;
|
||||||
|
div = new_iv;
|
||||||
|
dec.extend_from_slice(&chunk);
|
||||||
|
let _ = raw.split_to(blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse RPC frames
|
||||||
|
while dec.len() >= 12 {
|
||||||
|
let fl = u32::from_le_bytes([dec[0], dec[1], dec[2], dec[3]]) as usize;
|
||||||
|
if fl == 4 { let _ = dec.split_to(4); continue; }
|
||||||
|
if fl < 12 || fl > (1 << 24) {
|
||||||
|
warn!(frame_len = fl, "Invalid RPC frame len");
|
||||||
|
dec.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dec.len() < fl { break; }
|
||||||
|
|
||||||
|
let frame = dec.split_to(fl);
|
||||||
|
let pe = fl - 4;
|
||||||
|
let ec = u32::from_le_bytes([frame[pe], frame[pe+1], frame[pe+2], frame[pe+3]]);
|
||||||
|
if crc32(&frame[..pe]) != ec {
|
||||||
|
warn!("CRC mismatch in data frame");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = &frame[8..pe];
|
||||||
|
if payload.len() < 4 { continue; }
|
||||||
|
let pt = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
|
||||||
|
let body = &payload[4..];
|
||||||
|
|
||||||
|
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
|
||||||
|
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
|
||||||
|
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
|
||||||
|
let data = Bytes::copy_from_slice(&body[12..]);
|
||||||
|
trace!(cid, len = data.len(), flags, "ANS");
|
||||||
|
reg.route(cid, MeResponse::Data(data)).await;
|
||||||
|
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
|
||||||
|
trace!(cid, cfm, "ACK");
|
||||||
|
reg.route(cid, MeResponse::Ack(cfm)).await;
|
||||||
|
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
debug!(cid, "CLOSE_EXT from ME");
|
||||||
|
reg.route(cid, MeResponse::Close).await;
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
debug!(cid, "CLOSE_CONN from ME");
|
||||||
|
reg.route(cid, MeResponse::Close).await;
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
} else if pt == RPC_PING_U32 && body.len() >= 8 {
|
||||||
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
trace!(ping_id, "RPC_PING -> PONG");
|
||||||
|
let mut pong = Vec::with_capacity(12);
|
||||||
|
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
|
||||||
|
pong.extend_from_slice(&ping_id.to_le_bytes());
|
||||||
|
if let Err(e) = writer.lock().await.send(&pong).await {
|
||||||
|
warn!(error = %e, "PONG send failed");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!(rpc_type = format_args!("0x{:08x}", pt), len = body.len(), "Unknown RPC");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== Proto flags ==========
|
||||||
|
|
||||||
|
/// Map ProtoTag to C-compatible RPC_PROXY_REQ transport flags.
|
||||||
|
/// C: RPC_F_COMPACT(0x40000000)=abridged, RPC_F_MEDIUM(0x20000000)=intermediate/secure
|
||||||
|
/// The 0x1000(magic) and 0x8(proxy_tag) are added inside build_proxy_req_payload.
|
||||||
|
|
||||||
|
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag) -> u32 {
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
let mut flags = RPC_FLAG_HAS_AD_TAG | RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
|
||||||
|
match tag {
|
||||||
|
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
|
||||||
|
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
|
||||||
|
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ========== Health Monitor (Phase 4) ==========
|
||||||
|
|
||||||
|
pub async fn me_health_monitor(
|
||||||
|
pool: Arc<MePool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
min_connections: usize,
|
||||||
|
) {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
let current = pool.writers.read().await.len();
|
||||||
|
if current < min_connections {
|
||||||
|
warn!(current, min = min_connections, "ME pool below minimum, reconnecting...");
|
||||||
|
let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone();
|
||||||
|
for &(ip, port) in addrs.iter() {
|
||||||
|
let needed = min_connections.saturating_sub(pool.writers.read().await.len());
|
||||||
|
if needed == 0 { break; }
|
||||||
|
for _ in 0..needed {
|
||||||
|
let addr = SocketAddr::new(ip, port);
|
||||||
|
match pool.connect_one(addr, &rng).await {
|
||||||
|
Ok(()) => info!(%addr, "ME reconnected"),
|
||||||
|
Err(e) => debug!(%addr, error = %e, "ME reconnect failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
179
src/transport/middle_proxy/codec.rs
Normal file
179
src/transport/middle_proxy/codec.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
use crate::crypto::{AesCbc, crc32};
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
|
||||||
|
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
|
||||||
|
let total_len = (4 + 4 + payload.len() + 4) as u32;
|
||||||
|
let mut frame = Vec::with_capacity(total_len as usize);
|
||||||
|
frame.extend_from_slice(&total_len.to_le_bytes());
|
||||||
|
frame.extend_from_slice(&seq_no.to_le_bytes());
|
||||||
|
frame.extend_from_slice(payload);
|
||||||
|
let c = crc32(&frame);
|
||||||
|
frame.extend_from_slice(&c.to_le_bytes());
|
||||||
|
frame
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn read_rpc_frame_plaintext(
|
||||||
|
rd: &mut (impl AsyncReadExt + Unpin),
|
||||||
|
) -> Result<(i32, Vec<u8>)> {
|
||||||
|
let mut len_buf = [0u8; 4];
|
||||||
|
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
|
||||||
|
let total_len = u32::from_le_bytes(len_buf) as usize;
|
||||||
|
|
||||||
|
if !(12..=(1 << 24)).contains(&total_len) {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Bad RPC frame length: {total_len}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rest = vec![0u8; total_len - 4];
|
||||||
|
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
let mut full = Vec::with_capacity(total_len);
|
||||||
|
full.extend_from_slice(&len_buf);
|
||||||
|
full.extend_from_slice(&rest);
|
||||||
|
|
||||||
|
let crc_offset = total_len - 4;
|
||||||
|
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
|
||||||
|
let actual_crc = crc32(&full[..crc_offset]);
|
||||||
|
if expected_crc != actual_crc {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap());
|
||||||
|
let payload = full[8..crc_offset].to_vec();
|
||||||
|
Ok((seq_no, payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
|
||||||
|
let mut p = [0u8; 32];
|
||||||
|
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
|
||||||
|
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
|
||||||
|
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
|
||||||
|
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
|
||||||
|
p[16..32].copy_from_slice(nonce);
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, u32, [u8; 16])> {
|
||||||
|
if d.len() < 32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Nonce payload too short: {} bytes",
|
||||||
|
d.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = u32::from_le_bytes(d[0..4].try_into().unwrap());
|
||||||
|
if t != RPC_NONCE_U32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_select = u32::from_le_bytes(d[4..8].try_into().unwrap());
|
||||||
|
let schema = u32::from_le_bytes(d[8..12].try_into().unwrap());
|
||||||
|
let ts = u32::from_le_bytes(d[12..16].try_into().unwrap());
|
||||||
|
let mut nonce = [0u8; 16];
|
||||||
|
nonce.copy_from_slice(&d[16..32]);
|
||||||
|
Ok((key_select, schema, ts, nonce))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_handshake_payload(
|
||||||
|
our_ip: [u8; 4],
|
||||||
|
our_port: u16,
|
||||||
|
peer_ip: [u8; 4],
|
||||||
|
peer_port: u16,
|
||||||
|
) -> [u8; 32] {
|
||||||
|
let mut p = [0u8; 32];
|
||||||
|
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
|
||||||
|
|
||||||
|
// Keep C memory layout compatibility for PID IPv4 bytes.
|
||||||
|
p[8..12].copy_from_slice(&our_ip);
|
||||||
|
p[12..14].copy_from_slice(&our_port.to_le_bytes());
|
||||||
|
let pid = (std::process::id() & 0xffff) as u16;
|
||||||
|
p[14..16].copy_from_slice(&pid.to_le_bytes());
|
||||||
|
let utime = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as u32;
|
||||||
|
p[16..20].copy_from_slice(&utime.to_le_bytes());
|
||||||
|
|
||||||
|
p[20..24].copy_from_slice(&peer_ip);
|
||||||
|
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn cbc_encrypt_padded(
|
||||||
|
key: &[u8; 32],
|
||||||
|
iv: &[u8; 16],
|
||||||
|
plaintext: &[u8],
|
||||||
|
) -> Result<(Vec<u8>, [u8; 16])> {
|
||||||
|
let pad = (16 - (plaintext.len() % 16)) % 16;
|
||||||
|
let mut buf = plaintext.to_vec();
|
||||||
|
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
|
||||||
|
for i in 0..pad {
|
||||||
|
buf.push(pad_pattern[i % 4]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(*key, *iv);
|
||||||
|
cipher
|
||||||
|
.encrypt_in_place(&mut buf)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?;
|
||||||
|
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
if buf.len() >= 16 {
|
||||||
|
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
|
||||||
|
}
|
||||||
|
Ok((buf, new_iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn cbc_decrypt_inplace(
|
||||||
|
key: &[u8; 32],
|
||||||
|
iv: &[u8; 16],
|
||||||
|
data: &mut [u8],
|
||||||
|
) -> Result<[u8; 16]> {
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
if data.len() >= 16 {
|
||||||
|
new_iv.copy_from_slice(&data[data.len() - 16..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
AesCbc::new(*key, *iv)
|
||||||
|
.decrypt_in_place(data)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?;
|
||||||
|
Ok(new_iv)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct RpcWriter {
|
||||||
|
pub(crate) writer: tokio::io::WriteHalf<tokio::net::TcpStream>,
|
||||||
|
pub(crate) key: [u8; 32],
|
||||||
|
pub(crate) iv: [u8; 16],
|
||||||
|
pub(crate) seq_no: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RpcWriter {
|
||||||
|
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
|
||||||
|
let frame = build_rpc_frame(self.seq_no, payload);
|
||||||
|
self.seq_no += 1;
|
||||||
|
|
||||||
|
let pad = (16 - (frame.len() % 16)) % 16;
|
||||||
|
let mut buf = frame;
|
||||||
|
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
|
||||||
|
for i in 0..pad {
|
||||||
|
buf.push(pad_pattern[i % 4]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(self.key, self.iv);
|
||||||
|
cipher
|
||||||
|
.encrypt_in_place(&mut buf)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
|
||||||
|
|
||||||
|
if buf.len() >= 16 {
|
||||||
|
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
|
||||||
|
}
|
||||||
|
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
|
||||||
|
}
|
||||||
|
}
|
||||||
38
src/transport/middle_proxy/health.rs
Normal file
38
src/transport/middle_proxy/health.rs
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::protocol::constants::TG_MIDDLE_PROXIES_FLAT_V4;
|
||||||
|
|
||||||
|
use super::MePool;
|
||||||
|
|
||||||
|
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, min_connections: usize) {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
let current = pool.connection_count();
|
||||||
|
if current < min_connections {
|
||||||
|
warn!(
|
||||||
|
current,
|
||||||
|
min = min_connections,
|
||||||
|
"ME pool below minimum, reconnecting..."
|
||||||
|
);
|
||||||
|
let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone();
|
||||||
|
for &(ip, port) in addrs.iter() {
|
||||||
|
let needed = min_connections.saturating_sub(pool.connection_count());
|
||||||
|
if needed == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for _ in 0..needed {
|
||||||
|
let addr = SocketAddr::new(ip, port);
|
||||||
|
match pool.connect_one(addr, &rng).await {
|
||||||
|
Ok(()) => info!(%addr, "ME reconnected"),
|
||||||
|
Err(e) => debug!(%addr, error = %e, "ME reconnect failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
26
src/transport/middle_proxy/mod.rs
Normal file
26
src/transport/middle_proxy/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
//! Middle Proxy RPC transport.
|
||||||
|
|
||||||
|
mod codec;
|
||||||
|
mod health;
|
||||||
|
mod pool;
|
||||||
|
mod pool_nat;
|
||||||
|
mod reader;
|
||||||
|
mod registry;
|
||||||
|
mod send;
|
||||||
|
mod secret;
|
||||||
|
mod wire;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
pub use health::me_health_monitor;
|
||||||
|
pub use pool::MePool;
|
||||||
|
pub use registry::ConnRegistry;
|
||||||
|
pub use secret::fetch_proxy_secret;
|
||||||
|
pub use wire::proto_flags_for_tag;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum MeResponse {
|
||||||
|
Data { flags: u32, data: Bytes },
|
||||||
|
Ack(u32),
|
||||||
|
Close,
|
||||||
|
}
|
||||||
499
src/transport/middle_proxy/pool.rs
Normal file
499
src/transport/middle_proxy/pool.rs
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
use std::sync::atomic::AtomicU64;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use rand::Rng;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
use tokio::time::{Instant, timeout};
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
|
||||||
|
use super::ConnRegistry;
|
||||||
|
use super::codec::{
|
||||||
|
RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace,
|
||||||
|
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext,
|
||||||
|
};
|
||||||
|
use super::reader::reader_loop;
|
||||||
|
use super::wire::{IpMaterial, extract_ip_material};
|
||||||
|
|
||||||
|
const ME_ACTIVE_PING_SECS: u64 = 25;
|
||||||
|
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
|
||||||
|
|
||||||
|
pub struct MePool {
|
||||||
|
pub(super) registry: Arc<ConnRegistry>,
|
||||||
|
pub(super) writers: Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> ,
|
||||||
|
pub(super) rr: AtomicU64,
|
||||||
|
pub(super) proxy_tag: Option<Vec<u8>>,
|
||||||
|
proxy_secret: Vec<u8>,
|
||||||
|
pub(super) nat_ip_cfg: Option<IpAddr>,
|
||||||
|
pub(super) nat_ip_detected: OnceLock<IpAddr>,
|
||||||
|
pub(super) nat_probe: bool,
|
||||||
|
pub(super) nat_stun: Option<String>,
|
||||||
|
pool_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MePool {
|
||||||
|
pub fn new(
|
||||||
|
proxy_tag: Option<Vec<u8>>,
|
||||||
|
proxy_secret: Vec<u8>,
|
||||||
|
nat_ip: Option<IpAddr>,
|
||||||
|
nat_probe: bool,
|
||||||
|
nat_stun: Option<String>,
|
||||||
|
) -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
registry: Arc::new(ConnRegistry::new()),
|
||||||
|
writers: Arc::new(RwLock::new(Vec::new())),
|
||||||
|
rr: AtomicU64::new(0),
|
||||||
|
proxy_tag,
|
||||||
|
proxy_secret,
|
||||||
|
nat_ip_cfg: nat_ip,
|
||||||
|
nat_ip_detected: OnceLock::new(),
|
||||||
|
nat_probe,
|
||||||
|
nat_stun,
|
||||||
|
pool_size: 2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_proxy_tag(&self) -> bool {
|
||||||
|
self.proxy_tag.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
|
||||||
|
let ip = self.translate_ip_for_nat(addr.ip());
|
||||||
|
SocketAddr::new(ip, addr.port())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn registry(&self) -> &Arc<ConnRegistry> {
|
||||||
|
&self.registry
|
||||||
|
}
|
||||||
|
|
||||||
|
fn writers_arc(&self) -> Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>>
|
||||||
|
{
|
||||||
|
self.writers.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn key_selector(&self) -> u32 {
|
||||||
|
if self.proxy_secret.len() >= 4 {
|
||||||
|
u32::from_le_bytes([
|
||||||
|
self.proxy_secret[0],
|
||||||
|
self.proxy_secret[1],
|
||||||
|
self.proxy_secret[2],
|
||||||
|
self.proxy_secret[3],
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> {
|
||||||
|
let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4;
|
||||||
|
let ks = self.key_selector();
|
||||||
|
info!(
|
||||||
|
me_servers = addrs.len(),
|
||||||
|
pool_size,
|
||||||
|
key_selector = format_args!("0x{ks:08x}"),
|
||||||
|
secret_len = self.proxy_secret.len(),
|
||||||
|
"Initializing ME pool"
|
||||||
|
);
|
||||||
|
|
||||||
|
for &(ip, port) in addrs.iter() {
|
||||||
|
for i in 0..pool_size {
|
||||||
|
let addr = SocketAddr::new(ip, port);
|
||||||
|
match self.connect_one(addr, rng).await {
|
||||||
|
Ok(()) => info!(%addr, idx = i, "ME connected"),
|
||||||
|
Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.writers.read().await.len() >= pool_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.writers.read().await.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("No ME connections".into()));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn connect_one(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
addr: SocketAddr,
|
||||||
|
rng: &SecureRandom,
|
||||||
|
) -> Result<()> {
|
||||||
|
let secret = &self.proxy_secret;
|
||||||
|
if secret.len() < 32 {
|
||||||
|
return Err(ProxyError::Proxy(
|
||||||
|
"proxy-secret too short for ME auth".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = timeout(
|
||||||
|
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
|
||||||
|
TcpStream::connect(addr),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ProxyError::ConnectionTimeout {
|
||||||
|
addr: addr.to_string(),
|
||||||
|
})?
|
||||||
|
.map_err(ProxyError::Io)?;
|
||||||
|
stream.set_nodelay(true).ok();
|
||||||
|
|
||||||
|
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
|
||||||
|
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
|
||||||
|
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
|
||||||
|
let reflected = if self.nat_probe {
|
||||||
|
self.maybe_reflect_public_addr().await
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
|
||||||
|
let peer_addr_nat =
|
||||||
|
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
|
||||||
|
let (mut rd, mut wr) = tokio::io::split(stream);
|
||||||
|
|
||||||
|
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
|
||||||
|
let crypto_ts = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as u32;
|
||||||
|
|
||||||
|
let ks = self.key_selector();
|
||||||
|
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
|
||||||
|
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
|
||||||
|
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
|
||||||
|
info!(
|
||||||
|
key_selector = format_args!("0x{ks:08x}"),
|
||||||
|
crypto_ts,
|
||||||
|
frame_len = nonce_frame.len(),
|
||||||
|
nonce_frame_hex = %dump,
|
||||||
|
"Sending ME nonce frame"
|
||||||
|
);
|
||||||
|
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
|
||||||
|
wr.flush().await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
let (srv_seq, srv_nonce_payload) = timeout(
|
||||||
|
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
|
||||||
|
read_rpc_frame_plaintext(&mut rd),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
|
||||||
|
|
||||||
|
if srv_seq != -2 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Expected seq=-2, got {srv_seq}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
|
||||||
|
if schema != RPC_CRYPTO_AES_U32 {
|
||||||
|
warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema");
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Unsupported crypto schema: 0x{schema:x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv_key_select != ks {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let skew = crypto_ts.abs_diff(srv_ts);
|
||||||
|
if skew > 30 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
%local_addr,
|
||||||
|
%local_addr_nat,
|
||||||
|
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
|
||||||
|
%peer_addr,
|
||||||
|
%peer_addr_nat,
|
||||||
|
key_selector = format_args!("0x{ks:08x}"),
|
||||||
|
crypto_schema = format_args!("0x{schema:08x}"),
|
||||||
|
skew_secs = skew,
|
||||||
|
"ME key derivation parameters"
|
||||||
|
);
|
||||||
|
|
||||||
|
let ts_bytes = crypto_ts.to_le_bytes();
|
||||||
|
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
|
||||||
|
let client_port_bytes = local_addr_nat.port().to_le_bytes();
|
||||||
|
|
||||||
|
let server_ip = extract_ip_material(peer_addr_nat);
|
||||||
|
let client_ip = extract_ip_material(local_addr_nat);
|
||||||
|
|
||||||
|
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
|
||||||
|
match (server_ip, client_ip) {
|
||||||
|
(IpMaterial::V4(srv), IpMaterial::V4(clt)) => {
|
||||||
|
(Some(srv), Some(clt), None, None, clt, srv)
|
||||||
|
}
|
||||||
|
(IpMaterial::V6(srv), IpMaterial::V6(clt)) => {
|
||||||
|
let zero = [0u8; 4];
|
||||||
|
(None, None, Some(clt), Some(srv), zero, zero)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
"mixed IPv4/IPv6 endpoints are not supported for ME key derivation"
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let diag_level: u8 = std::env::var("ME_DIAG")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let prekey_client = build_middleproxy_prekey(
|
||||||
|
&srv_nonce,
|
||||||
|
&my_nonce,
|
||||||
|
&ts_bytes,
|
||||||
|
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&client_port_bytes,
|
||||||
|
b"CLIENT",
|
||||||
|
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&server_port_bytes,
|
||||||
|
secret,
|
||||||
|
clt_v6_opt.as_ref(),
|
||||||
|
srv_v6_opt.as_ref(),
|
||||||
|
);
|
||||||
|
let prekey_server = build_middleproxy_prekey(
|
||||||
|
&srv_nonce,
|
||||||
|
&my_nonce,
|
||||||
|
&ts_bytes,
|
||||||
|
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&client_port_bytes,
|
||||||
|
b"SERVER",
|
||||||
|
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&server_port_bytes,
|
||||||
|
secret,
|
||||||
|
clt_v6_opt.as_ref(),
|
||||||
|
srv_v6_opt.as_ref(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let (wk, wi) = derive_middleproxy_keys(
|
||||||
|
&srv_nonce,
|
||||||
|
&my_nonce,
|
||||||
|
&ts_bytes,
|
||||||
|
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&client_port_bytes,
|
||||||
|
b"CLIENT",
|
||||||
|
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&server_port_bytes,
|
||||||
|
secret,
|
||||||
|
clt_v6_opt.as_ref(),
|
||||||
|
srv_v6_opt.as_ref(),
|
||||||
|
);
|
||||||
|
let (rk, ri) = derive_middleproxy_keys(
|
||||||
|
&srv_nonce,
|
||||||
|
&my_nonce,
|
||||||
|
&ts_bytes,
|
||||||
|
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&client_port_bytes,
|
||||||
|
b"SERVER",
|
||||||
|
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||||
|
&server_port_bytes,
|
||||||
|
secret,
|
||||||
|
clt_v6_opt.as_ref(),
|
||||||
|
srv_v6_opt.as_ref(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let hs_payload =
|
||||||
|
build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port());
|
||||||
|
let hs_frame = build_rpc_frame(-1, &hs_payload);
|
||||||
|
if diag_level >= 1 {
|
||||||
|
info!(
|
||||||
|
write_key = %hex_dump(&wk),
|
||||||
|
write_iv = %hex_dump(&wi),
|
||||||
|
read_key = %hex_dump(&rk),
|
||||||
|
read_iv = %hex_dump(&ri),
|
||||||
|
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
|
||||||
|
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
|
||||||
|
srv_port = %hex_dump(&server_port_bytes),
|
||||||
|
clt_port = %hex_dump(&client_port_bytes),
|
||||||
|
crypto_ts = %hex_dump(&ts_bytes),
|
||||||
|
nonce_srv = %hex_dump(&srv_nonce),
|
||||||
|
nonce_clt = %hex_dump(&my_nonce),
|
||||||
|
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
|
||||||
|
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
|
||||||
|
hs_plain = %hex_dump(&hs_frame),
|
||||||
|
proxy_secret_sha256 = %hex_dump(&sha256(secret)),
|
||||||
|
"ME diag: derived keys and handshake plaintext"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if diag_level >= 2 {
|
||||||
|
info!(
|
||||||
|
prekey_client = %hex_dump(&prekey_client),
|
||||||
|
prekey_server = %hex_dump(&prekey_server),
|
||||||
|
"ME diag: full prekey buffers"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
|
||||||
|
if diag_level >= 1 {
|
||||||
|
info!(
|
||||||
|
hs_cipher = %hex_dump(&encrypted_hs),
|
||||||
|
"ME diag: handshake ciphertext"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
|
||||||
|
wr.flush().await.map_err(ProxyError::Io)?;
|
||||||
|
|
||||||
|
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
|
||||||
|
let mut enc_buf = BytesMut::with_capacity(256);
|
||||||
|
let mut dec_buf = BytesMut::with_capacity(256);
|
||||||
|
let mut read_iv = ri;
|
||||||
|
let mut handshake_ok = false;
|
||||||
|
|
||||||
|
while Instant::now() < deadline && !handshake_ok {
|
||||||
|
let remaining = deadline - Instant::now();
|
||||||
|
let mut tmp = [0u8; 256];
|
||||||
|
let n = match timeout(remaining, rd.read(&mut tmp)).await {
|
||||||
|
Ok(Ok(0)) => {
|
||||||
|
return Err(ProxyError::Io(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::UnexpectedEof,
|
||||||
|
"ME closed during handshake",
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(Ok(n)) => n,
|
||||||
|
Ok(Err(e)) => return Err(ProxyError::Io(e)),
|
||||||
|
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
|
||||||
|
};
|
||||||
|
|
||||||
|
enc_buf.extend_from_slice(&tmp[..n]);
|
||||||
|
|
||||||
|
let blocks = enc_buf.len() / 16 * 16;
|
||||||
|
if blocks > 0 {
|
||||||
|
let mut chunk = vec![0u8; blocks];
|
||||||
|
chunk.copy_from_slice(&enc_buf[..blocks]);
|
||||||
|
read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
|
||||||
|
dec_buf.extend_from_slice(&chunk);
|
||||||
|
let _ = enc_buf.split_to(blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
while dec_buf.len() >= 4 {
|
||||||
|
let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize;
|
||||||
|
|
||||||
|
if fl == 4 {
|
||||||
|
let _ = dec_buf.split_to(4);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if !(12..=(1 << 24)).contains(&fl) {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Bad HS response frame len: {fl}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if dec_buf.len() < fl {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame = dec_buf.split_to(fl);
|
||||||
|
let pe = fl - 4;
|
||||||
|
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
|
||||||
|
let ac = crate::crypto::crc32(&frame[..pe]);
|
||||||
|
if ec != ac {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap());
|
||||||
|
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
|
||||||
|
let err_code = if frame.len() >= 16 {
|
||||||
|
i32::from_le_bytes(frame[12..16].try_into().unwrap())
|
||||||
|
} else {
|
||||||
|
-1
|
||||||
|
};
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"ME rejected handshake (error={err_code})"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if hs_type != RPC_HANDSHAKE_U32 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(format!(
|
||||||
|
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
handshake_ok = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !handshake_ok {
|
||||||
|
return Err(ProxyError::TgHandshakeTimeout);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(%addr, "RPC handshake OK");
|
||||||
|
|
||||||
|
let rpc_w = Arc::new(Mutex::new(RpcWriter {
|
||||||
|
writer: wr,
|
||||||
|
key: wk,
|
||||||
|
iv: write_iv,
|
||||||
|
seq_no: 0,
|
||||||
|
}));
|
||||||
|
self.writers.write().await.push((addr, rpc_w.clone()));
|
||||||
|
|
||||||
|
let reg = self.registry.clone();
|
||||||
|
let w_pong = rpc_w.clone();
|
||||||
|
let w_pool = self.writers_arc();
|
||||||
|
let w_ping = rpc_w.clone();
|
||||||
|
let w_pool_ping = self.writers_arc();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) =
|
||||||
|
reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await
|
||||||
|
{
|
||||||
|
warn!(error = %e, "ME reader ended");
|
||||||
|
}
|
||||||
|
let mut ws = w_pool.write().await;
|
||||||
|
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong));
|
||||||
|
info!(remaining = ws.len(), "Dead ME writer removed from pool");
|
||||||
|
});
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut ping_id: i64 = rand::random::<i64>();
|
||||||
|
loop {
|
||||||
|
let jitter = rand::rng()
|
||||||
|
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
|
||||||
|
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
|
||||||
|
tokio::time::sleep(Duration::from_secs(wait)).await;
|
||||||
|
let mut p = Vec::with_capacity(12);
|
||||||
|
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
|
||||||
|
p.extend_from_slice(&ping_id.to_le_bytes());
|
||||||
|
ping_id = ping_id.wrapping_add(1);
|
||||||
|
if let Err(e) = w_ping.lock().await.send(&p).await {
|
||||||
|
debug!(error = %e, "Active ME ping failed, removing dead writer");
|
||||||
|
let mut ws = w_pool_ping.write().await;
|
||||||
|
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hex_dump(data: &[u8]) -> String {
|
||||||
|
const MAX: usize = 64;
|
||||||
|
let mut out = String::with_capacity(data.len() * 2 + 3);
|
||||||
|
for (i, b) in data.iter().take(MAX).enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
out.push(' ');
|
||||||
|
}
|
||||||
|
out.push_str(&format!("{b:02x}"));
|
||||||
|
}
|
||||||
|
if data.len() > MAX {
|
||||||
|
out.push_str(" …");
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
200
src/transport/middle_proxy/pool_nat.rs
Normal file
200
src/transport/middle_proxy/pool_nat.rs
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
|
use super::MePool;
|
||||||
|
|
||||||
|
impl MePool {
|
||||||
|
pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
|
||||||
|
let nat_ip = self
|
||||||
|
.nat_ip_cfg
|
||||||
|
.or_else(|| self.nat_ip_detected.get().copied());
|
||||||
|
|
||||||
|
let Some(nat_ip) = nat_ip else {
|
||||||
|
return ip;
|
||||||
|
};
|
||||||
|
|
||||||
|
match (ip, nat_ip) {
|
||||||
|
(IpAddr::V4(src), IpAddr::V4(dst))
|
||||||
|
if is_privateish(IpAddr::V4(src))
|
||||||
|
|| src.is_loopback()
|
||||||
|
|| src.is_unspecified() =>
|
||||||
|
{
|
||||||
|
IpAddr::V4(dst)
|
||||||
|
}
|
||||||
|
(IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => {
|
||||||
|
IpAddr::V6(dst)
|
||||||
|
}
|
||||||
|
(orig, _) => orig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn translate_our_addr_with_reflection(
|
||||||
|
&self,
|
||||||
|
addr: std::net::SocketAddr,
|
||||||
|
reflected: Option<std::net::SocketAddr>,
|
||||||
|
) -> std::net::SocketAddr {
|
||||||
|
let ip = if let Some(r) = reflected {
|
||||||
|
// Use reflected IP (not port) only when local address is non-public.
|
||||||
|
if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
|
||||||
|
r.ip()
|
||||||
|
} else {
|
||||||
|
self.translate_ip_for_nat(addr.ip())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.translate_ip_for_nat(addr.ip())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Keep the kernel-assigned TCP source port; STUN port can differ.
|
||||||
|
std::net::SocketAddr::new(ip, addr.port())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option<IpAddr> {
|
||||||
|
if self.nat_ip_cfg.is_some() {
|
||||||
|
return self.nat_ip_cfg;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ip) = self.nat_ip_detected.get().copied() {
|
||||||
|
return Some(ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
match fetch_public_ipv4().await {
|
||||||
|
Ok(Some(ip)) => {
|
||||||
|
let _ = self.nat_ip_detected.set(IpAddr::V4(ip));
|
||||||
|
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
|
||||||
|
Some(IpAddr::V4(ip))
|
||||||
|
}
|
||||||
|
Ok(None) => None,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Failed to auto-detect public IP");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> {
|
||||||
|
let stun_addr = self
|
||||||
|
.nat_stun
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
|
||||||
|
match fetch_stun_binding(&stun_addr).await {
|
||||||
|
Ok(sa) => {
|
||||||
|
if let Some(sa) = sa {
|
||||||
|
info!(%sa, "NAT probe: reflected address");
|
||||||
|
}
|
||||||
|
sa
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "NAT probe failed");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_public_ipv4() -> Result<Option<Ipv4Addr>> {
|
||||||
|
let res = reqwest::get("https://checkip.amazonaws.com").await.map_err(|e| {
|
||||||
|
ProxyError::Proxy(format!("public IP detection request failed: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let text = res.text().await.map_err(|e| {
|
||||||
|
ProxyError::Proxy(format!("public IP detection read failed: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let ip = text.trim().parse().ok();
|
||||||
|
Ok(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_stun_binding(stun_addr: &str) -> Result<Option<std::net::SocketAddr>> {
|
||||||
|
use rand::RngCore;
|
||||||
|
use tokio::net::UdpSocket;
|
||||||
|
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
|
||||||
|
socket
|
||||||
|
.connect(stun_addr)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
|
||||||
|
|
||||||
|
// Build minimal Binding Request.
|
||||||
|
let mut req = vec![0u8; 20];
|
||||||
|
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
|
||||||
|
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
|
||||||
|
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
|
||||||
|
rand::thread_rng().fill_bytes(&mut req[8..20]);
|
||||||
|
|
||||||
|
socket
|
||||||
|
.send(&req)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
|
||||||
|
|
||||||
|
let mut buf = [0u8; 128];
|
||||||
|
let n = socket
|
||||||
|
.recv(&mut buf)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?;
|
||||||
|
if n < 20 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse attributes.
|
||||||
|
let mut idx = 20;
|
||||||
|
while idx + 4 <= n {
|
||||||
|
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
|
||||||
|
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
|
||||||
|
idx += 4;
|
||||||
|
if idx + alen > n {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
match atype {
|
||||||
|
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
|
||||||
|
if alen < 8 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let family = buf[idx + 1];
|
||||||
|
if family != 0x01 {
|
||||||
|
// only IPv4 supported here
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let port_bytes = [buf[idx + 2], buf[idx + 3]];
|
||||||
|
let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]];
|
||||||
|
|
||||||
|
let (port, ip) = if atype == 0x0020 {
|
||||||
|
let magic = 0x2112A442u32.to_be_bytes();
|
||||||
|
let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16);
|
||||||
|
let ip = [
|
||||||
|
ip_bytes[0] ^ magic[0],
|
||||||
|
ip_bytes[1] ^ magic[1],
|
||||||
|
ip_bytes[2] ^ magic[2],
|
||||||
|
ip_bytes[3] ^ magic[3],
|
||||||
|
];
|
||||||
|
(port, ip)
|
||||||
|
} else {
|
||||||
|
(u16::from_be_bytes(port_bytes), ip_bytes)
|
||||||
|
};
|
||||||
|
return Ok(Some(std::net::SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])),
|
||||||
|
port,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
idx += (alen + 3) & !3; // 4-byte alignment
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_privateish(ip: IpAddr) -> bool {
|
||||||
|
match ip {
|
||||||
|
IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
|
||||||
|
IpAddr::V6(v6) => v6.is_unique_local(),
|
||||||
|
}
|
||||||
|
}
|
||||||
141
src/transport/middle_proxy/reader.rs
Normal file
141
src/transport/middle_proxy/reader.rs
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
|
use crate::crypto::{AesCbc, crc32};
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
|
||||||
|
use super::codec::RpcWriter;
|
||||||
|
use super::{ConnRegistry, MeResponse};
|
||||||
|
|
||||||
|
pub(crate) async fn reader_loop(
|
||||||
|
mut rd: tokio::io::ReadHalf<TcpStream>,
|
||||||
|
dk: [u8; 32],
|
||||||
|
mut div: [u8; 16],
|
||||||
|
reg: Arc<ConnRegistry>,
|
||||||
|
enc_leftover: BytesMut,
|
||||||
|
mut dec: BytesMut,
|
||||||
|
writer: Arc<Mutex<RpcWriter>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut raw = enc_leftover;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut tmp = [0u8; 16_384];
|
||||||
|
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
|
||||||
|
if n == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
raw.extend_from_slice(&tmp[..n]);
|
||||||
|
|
||||||
|
let blocks = raw.len() / 16 * 16;
|
||||||
|
if blocks > 0 {
|
||||||
|
let mut new_iv = [0u8; 16];
|
||||||
|
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
|
||||||
|
|
||||||
|
let mut chunk = vec![0u8; blocks];
|
||||||
|
chunk.copy_from_slice(&raw[..blocks]);
|
||||||
|
AesCbc::new(dk, div)
|
||||||
|
.decrypt_in_place(&mut chunk)
|
||||||
|
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
|
||||||
|
div = new_iv;
|
||||||
|
dec.extend_from_slice(&chunk);
|
||||||
|
let _ = raw.split_to(blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
while dec.len() >= 12 {
|
||||||
|
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
|
||||||
|
if fl == 4 {
|
||||||
|
let _ = dec.split_to(4);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if !(12..=(1 << 24)).contains(&fl) {
|
||||||
|
warn!(frame_len = fl, "Invalid RPC frame len");
|
||||||
|
dec.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dec.len() < fl {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame = dec.split_to(fl);
|
||||||
|
let pe = fl - 4;
|
||||||
|
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
|
||||||
|
if crc32(&frame[..pe]) != ec {
|
||||||
|
warn!("CRC mismatch in data frame");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = &frame[8..pe];
|
||||||
|
if payload.len() < 4 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
|
||||||
|
let body = &payload[4..];
|
||||||
|
|
||||||
|
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
|
||||||
|
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
|
||||||
|
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
|
||||||
|
let data = Bytes::copy_from_slice(&body[12..]);
|
||||||
|
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
|
||||||
|
|
||||||
|
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
|
||||||
|
if !routed {
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
send_close_conn(&writer, cid).await;
|
||||||
|
}
|
||||||
|
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
|
||||||
|
trace!(cid, cfm, "RPC_SIMPLE_ACK");
|
||||||
|
|
||||||
|
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
|
||||||
|
if !routed {
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
send_close_conn(&writer, cid).await;
|
||||||
|
}
|
||||||
|
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
debug!(cid, "RPC_CLOSE_EXT from ME");
|
||||||
|
reg.route(cid, MeResponse::Close).await;
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
|
||||||
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
debug!(cid, "RPC_CLOSE_CONN from ME");
|
||||||
|
reg.route(cid, MeResponse::Close).await;
|
||||||
|
reg.unregister(cid).await;
|
||||||
|
} else if pt == RPC_PING_U32 && body.len() >= 8 {
|
||||||
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||||
|
trace!(ping_id, "RPC_PING -> RPC_PONG");
|
||||||
|
let mut pong = Vec::with_capacity(12);
|
||||||
|
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
|
||||||
|
pong.extend_from_slice(&ping_id.to_le_bytes());
|
||||||
|
if let Err(e) = writer.lock().await.send(&pong).await {
|
||||||
|
warn!(error = %e, "PONG send failed");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!(
|
||||||
|
rpc_type = format_args!("0x{pt:08x}"),
|
||||||
|
len = body.len(),
|
||||||
|
"Unknown RPC"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_close_conn(writer: &Arc<Mutex<RpcWriter>>, conn_id: u64) {
|
||||||
|
let mut p = Vec::with_capacity(12);
|
||||||
|
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
|
||||||
|
p.extend_from_slice(&conn_id.to_le_bytes());
|
||||||
|
|
||||||
|
if let Err(e) = writer.lock().await.send(&p).await {
|
||||||
|
debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN");
|
||||||
|
}
|
||||||
|
}
|
||||||
42
src/transport/middle_proxy/registry.rs
Normal file
42
src/transport/middle_proxy/registry.rs
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
||||||
|
use tokio::sync::{RwLock, mpsc};
|
||||||
|
|
||||||
|
use super::MeResponse;
|
||||||
|
|
||||||
|
pub struct ConnRegistry {
|
||||||
|
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
|
||||||
|
next_id: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
// Avoid fully predictable conn_id sequence from 1.
|
||||||
|
let start = rand::random::<u64>() | 1;
|
||||||
|
Self {
|
||||||
|
map: RwLock::new(HashMap::new()),
|
||||||
|
next_id: AtomicU64::new(start),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
|
||||||
|
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let (tx, rx) = mpsc::channel(256);
|
||||||
|
self.map.write().await.insert(id, tx);
|
||||||
|
(id, rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unregister(&self, id: u64) {
|
||||||
|
self.map.write().await.remove(&id);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
|
||||||
|
let m = self.map.read().await;
|
||||||
|
if let Some(tx) = m.get(&id) {
|
||||||
|
tx.send(resp).await.is_ok()
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
81
src/transport/middle_proxy/secret.rs
Normal file
81
src/transport/middle_proxy/secret.rs
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
|
/// Fetch Telegram proxy-secret binary.
|
||||||
|
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
|
||||||
|
let cache = cache_path.unwrap_or("proxy-secret");
|
||||||
|
|
||||||
|
// 1) Try fresh download first.
|
||||||
|
match download_proxy_secret().await {
|
||||||
|
Ok(data) => {
|
||||||
|
if let Err(e) = tokio::fs::write(cache, &data).await {
|
||||||
|
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
|
||||||
|
} else {
|
||||||
|
debug!(path = cache, len = data.len(), "Cached proxy-secret");
|
||||||
|
}
|
||||||
|
return Ok(data);
|
||||||
|
}
|
||||||
|
Err(download_err) => {
|
||||||
|
warn!(error = %download_err, "Proxy-secret download failed, trying cache/file fallback");
|
||||||
|
// Fall through to cache/file.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) Fallback to cache/file regardless of age; require len>=32.
|
||||||
|
match tokio::fs::read(cache).await {
|
||||||
|
Ok(data) if data.len() >= 32 => {
|
||||||
|
let age_hours = tokio::fs::metadata(cache)
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
.and_then(|m| m.modified().ok())
|
||||||
|
.and_then(|m| std::time::SystemTime::now().duration_since(m).ok())
|
||||||
|
.map(|d| d.as_secs() / 3600);
|
||||||
|
info!(
|
||||||
|
path = cache,
|
||||||
|
len = data.len(),
|
||||||
|
age_hours,
|
||||||
|
"Loaded proxy-secret from cache/file after download failure"
|
||||||
|
);
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
Ok(data) => Err(ProxyError::Proxy(format!(
|
||||||
|
"Cached proxy-secret too short: {} bytes (need >= 32)",
|
||||||
|
data.len()
|
||||||
|
))),
|
||||||
|
Err(e) => Err(ProxyError::Proxy(format!(
|
||||||
|
"Failed to read proxy-secret cache after download failure: {e}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_proxy_secret() -> Result<Vec<u8>> {
|
||||||
|
let resp = reqwest::get("https://core.telegram.org/getProxySecret")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"proxy-secret download HTTP {}",
|
||||||
|
resp.status()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = resp
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))?
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
if data.len() < 32 {
|
||||||
|
return Err(ProxyError::Proxy(format!(
|
||||||
|
"proxy-secret too short: {} bytes (need >= 32)",
|
||||||
|
data.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(len = data.len(), "Downloaded proxy-secret OK");
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
146
src/transport/middle_proxy/send.rs
Normal file
146
src/transport/middle_proxy/send.rs
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
use crate::protocol::constants::{RPC_CLOSE_EXT_U32, TG_MIDDLE_PROXIES_V4};
|
||||||
|
|
||||||
|
use super::MePool;
|
||||||
|
use super::codec::RpcWriter;
|
||||||
|
use super::wire::build_proxy_req_payload;
|
||||||
|
|
||||||
|
impl MePool {
|
||||||
|
pub async fn send_proxy_req(
|
||||||
|
&self,
|
||||||
|
conn_id: u64,
|
||||||
|
target_dc: i16,
|
||||||
|
client_addr: SocketAddr,
|
||||||
|
our_addr: SocketAddr,
|
||||||
|
data: &[u8],
|
||||||
|
proto_flags: u32,
|
||||||
|
) -> Result<()> {
|
||||||
|
let payload = build_proxy_req_payload(
|
||||||
|
conn_id,
|
||||||
|
client_addr,
|
||||||
|
our_addr,
|
||||||
|
data,
|
||||||
|
self.proxy_tag.as_deref(),
|
||||||
|
proto_flags,
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
if ws.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||||
|
}
|
||||||
|
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws.iter().cloned().collect();
|
||||||
|
drop(ws);
|
||||||
|
|
||||||
|
let candidate_indices = candidate_indices_for_dc(&writers, target_dc);
|
||||||
|
if candidate_indices.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
|
||||||
|
}
|
||||||
|
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
|
||||||
|
|
||||||
|
// Prefer immediately available writer to avoid waiting on stalled connection.
|
||||||
|
for offset in 0..candidate_indices.len() {
|
||||||
|
let cidx = (start + offset) % candidate_indices.len();
|
||||||
|
let idx = candidate_indices[cidx];
|
||||||
|
let w = writers[idx].1.clone();
|
||||||
|
if let Ok(mut guard) = w.try_lock() {
|
||||||
|
let send_res = guard.send(&payload).await;
|
||||||
|
drop(guard);
|
||||||
|
match send_res {
|
||||||
|
Ok(()) => return Ok(()),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "ME write failed, removing dead conn");
|
||||||
|
let mut ws = self.writers.write().await;
|
||||||
|
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||||
|
if ws.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All writers are currently busy, wait for the selected one.
|
||||||
|
let w = writers[candidate_indices[start]].1.clone();
|
||||||
|
match w.lock().await.send(&payload).await {
|
||||||
|
Ok(()) => return Ok(()),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "ME write failed, removing dead conn");
|
||||||
|
let mut ws = self.writers.write().await;
|
||||||
|
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||||
|
if ws.is_empty() {
|
||||||
|
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
if !ws.is_empty() {
|
||||||
|
let w = ws[0].1.clone();
|
||||||
|
drop(ws);
|
||||||
|
let mut p = Vec::with_capacity(12);
|
||||||
|
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
|
||||||
|
p.extend_from_slice(&conn_id.to_le_bytes());
|
||||||
|
if let Err(e) = w.lock().await.send(&p).await {
|
||||||
|
debug!(error = %e, "ME close write failed");
|
||||||
|
let mut ws = self.writers.write().await;
|
||||||
|
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.registry.unregister(conn_id).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_count(&self) -> usize {
|
||||||
|
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn candidate_indices_for_dc(
|
||||||
|
writers: &[(SocketAddr, Arc<Mutex<RpcWriter>>)],
|
||||||
|
target_dc: i16,
|
||||||
|
) -> Vec<usize> {
|
||||||
|
let mut preferred = Vec::<SocketAddr>::new();
|
||||||
|
let key = target_dc as i32;
|
||||||
|
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&key) {
|
||||||
|
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
|
||||||
|
}
|
||||||
|
if preferred.is_empty() {
|
||||||
|
let abs = key.abs();
|
||||||
|
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&abs) {
|
||||||
|
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if preferred.is_empty() {
|
||||||
|
let abs = key.abs();
|
||||||
|
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&-abs) {
|
||||||
|
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if preferred.is_empty() {
|
||||||
|
return (0..writers.len()).collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = Vec::new();
|
||||||
|
for (idx, (addr, _)) in writers.iter().enumerate() {
|
||||||
|
if preferred.iter().any(|p| p == addr) {
|
||||||
|
out.push(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out.is_empty() {
|
||||||
|
return (0..writers.len()).collect();
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
106
src/transport/middle_proxy/wire.rs
Normal file
106
src/transport/middle_proxy/wire.rs
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
|
||||||
|
use crate::protocol::constants::*;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub(crate) enum IpMaterial {
|
||||||
|
V4([u8; 4]),
|
||||||
|
V6([u8; 16]),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial {
|
||||||
|
match addr.ip() {
|
||||||
|
IpAddr::V4(v4) => IpMaterial::V4(v4.octets()),
|
||||||
|
IpAddr::V6(v6) => {
|
||||||
|
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||||
|
IpMaterial::V4(v4.octets())
|
||||||
|
} else {
|
||||||
|
IpMaterial::V6(v6.octets())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] {
|
||||||
|
let mut buf = [0u8; 16];
|
||||||
|
|
||||||
|
// Matches tl_store_long(0) + tl_store_int(-0x10000).
|
||||||
|
buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes());
|
||||||
|
|
||||||
|
// Matches tl_store_int(htonl(remote_ip_host_order)).
|
||||||
|
let host_order = u32::from_ne_bytes(ip.octets());
|
||||||
|
let network_order = host_order.to_be();
|
||||||
|
buf[12..16].copy_from_slice(&network_order.to_le_bytes());
|
||||||
|
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_mapped_addr_and_port(buf: &mut Vec<u8>, addr: SocketAddr) {
|
||||||
|
match addr.ip() {
|
||||||
|
IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)),
|
||||||
|
IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()),
|
||||||
|
}
|
||||||
|
buf.extend_from_slice(&(addr.port() as u32).to_le_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_proxy_req_payload(
|
||||||
|
conn_id: u64,
|
||||||
|
client_addr: SocketAddr,
|
||||||
|
our_addr: SocketAddr,
|
||||||
|
data: &[u8],
|
||||||
|
proxy_tag: Option<&[u8]>,
|
||||||
|
proto_flags: u32,
|
||||||
|
) -> Vec<u8> {
|
||||||
|
let mut b = Vec::with_capacity(128 + data.len());
|
||||||
|
|
||||||
|
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
|
||||||
|
b.extend_from_slice(&proto_flags.to_le_bytes());
|
||||||
|
b.extend_from_slice(&conn_id.to_le_bytes());
|
||||||
|
|
||||||
|
append_mapped_addr_and_port(&mut b, client_addr);
|
||||||
|
append_mapped_addr_and_port(&mut b, our_addr);
|
||||||
|
|
||||||
|
if proto_flags & 12 != 0 {
|
||||||
|
let extra_start = b.len();
|
||||||
|
b.extend_from_slice(&0u32.to_le_bytes());
|
||||||
|
|
||||||
|
if let Some(tag) = proxy_tag {
|
||||||
|
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
|
||||||
|
|
||||||
|
if tag.len() < 254 {
|
||||||
|
b.push(tag.len() as u8);
|
||||||
|
b.extend_from_slice(tag);
|
||||||
|
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
|
||||||
|
b.extend(std::iter::repeat_n(0u8, pad));
|
||||||
|
} else {
|
||||||
|
b.push(0xfe);
|
||||||
|
let len_bytes = (tag.len() as u32).to_le_bytes();
|
||||||
|
b.extend_from_slice(&len_bytes[..3]);
|
||||||
|
b.extend_from_slice(tag);
|
||||||
|
let pad = (4 - (tag.len() % 4)) % 4;
|
||||||
|
b.extend(std::iter::repeat_n(0u8, pad));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let extra_bytes = (b.len() - extra_start - 4) as u32;
|
||||||
|
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
b.extend_from_slice(data);
|
||||||
|
b
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 {
|
||||||
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
|
||||||
|
let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
|
||||||
|
if has_proxy_tag {
|
||||||
|
flags |= RPC_FLAG_HAS_AD_TAG;
|
||||||
|
}
|
||||||
|
|
||||||
|
match tag {
|
||||||
|
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
|
||||||
|
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
|
||||||
|
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,4 +10,5 @@ pub use pool::ConnectionPool;
|
|||||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||||
pub use socket::*;
|
pub use socket::*;
|
||||||
pub use socks::*;
|
pub use socks::*;
|
||||||
pub use upstream::UpstreamManager;
|
pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager};
|
||||||
|
pub mod middle_proxy;
|
||||||
|
|||||||
@@ -1,26 +1,153 @@
|
|||||||
//! Upstream Management
|
//! Upstream Management with per-DC latency-weighted selection
|
||||||
|
//!
|
||||||
|
//! IPv6/IPv4 connectivity checks with configurable preference.
|
||||||
|
|
||||||
use std::net::{SocketAddr, IpAddr};
|
use std::net::{SocketAddr, IpAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::Instant;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use tracing::{debug, warn, error, info};
|
use tracing::{debug, warn, info, trace};
|
||||||
|
|
||||||
use crate::config::{UpstreamConfig, UpstreamType};
|
use crate::config::{UpstreamConfig, UpstreamType};
|
||||||
use crate::error::{Result, ProxyError};
|
use crate::error::{Result, ProxyError};
|
||||||
|
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
|
||||||
use crate::transport::socket::create_outgoing_socket_bound;
|
use crate::transport::socket::create_outgoing_socket_bound;
|
||||||
use crate::transport::socks::{connect_socks4, connect_socks5};
|
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||||
|
|
||||||
|
/// Number of Telegram datacenters
|
||||||
|
const NUM_DCS: usize = 5;
|
||||||
|
|
||||||
|
/// Timeout for individual DC ping attempt
|
||||||
|
const DC_PING_TIMEOUT_SECS: u64 = 5;
|
||||||
|
|
||||||
|
// ============= RTT Tracking =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
struct LatencyEma {
|
||||||
|
value_ms: Option<f64>,
|
||||||
|
alpha: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LatencyEma {
|
||||||
|
const fn new(alpha: f64) -> Self {
|
||||||
|
Self { value_ms: None, alpha }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, sample_ms: f64) {
|
||||||
|
self.value_ms = Some(match self.value_ms {
|
||||||
|
None => sample_ms,
|
||||||
|
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self) -> Option<f64> {
|
||||||
|
self.value_ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Per-DC IP Preference Tracking =============
|
||||||
|
|
||||||
|
/// Tracks which IP version works for each DC
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum IpPreference {
|
||||||
|
/// Not yet tested
|
||||||
|
Unknown,
|
||||||
|
/// IPv6 works
|
||||||
|
PreferV6,
|
||||||
|
/// Only IPv4 works (IPv6 failed)
|
||||||
|
PreferV4,
|
||||||
|
/// Both work
|
||||||
|
BothWork,
|
||||||
|
/// Both failed
|
||||||
|
Unavailable,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IpPreference {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Unknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Upstream State =============
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct UpstreamState {
|
struct UpstreamState {
|
||||||
config: UpstreamConfig,
|
config: UpstreamConfig,
|
||||||
healthy: bool,
|
healthy: bool,
|
||||||
fails: u32,
|
fails: u32,
|
||||||
last_check: std::time::Instant,
|
last_check: std::time::Instant,
|
||||||
|
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
|
||||||
|
dc_latency: [LatencyEma; NUM_DCS],
|
||||||
|
/// Per-DC IP version preference (learned from connectivity tests)
|
||||||
|
dc_ip_pref: [IpPreference; NUM_DCS],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl UpstreamState {
|
||||||
|
fn new(config: UpstreamConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
healthy: true,
|
||||||
|
fails: 0,
|
||||||
|
last_check: std::time::Instant::now(),
|
||||||
|
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
|
||||||
|
dc_ip_pref: [IpPreference::Unknown; NUM_DCS],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map DC index to latency array slot (0..NUM_DCS).
|
||||||
|
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
|
||||||
|
let abs_dc = dc_idx.unsigned_abs() as usize;
|
||||||
|
if abs_dc == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if abs_dc >= 1 && abs_dc <= NUM_DCS {
|
||||||
|
Some(abs_dc - 1)
|
||||||
|
} else {
|
||||||
|
// Unknown DC → default cluster (DC 2, index 1)
|
||||||
|
Some(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get latency for a specific DC, falling back to average across all known DCs
|
||||||
|
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
|
||||||
|
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
|
||||||
|
if let Some(ms) = self.dc_latency[di].get() {
|
||||||
|
return Some(ms);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (sum, count) = self.dc_latency.iter()
|
||||||
|
.filter_map(|l| l.get())
|
||||||
|
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
|
||||||
|
|
||||||
|
if count > 0 { Some(sum / count as f64) } else { None }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of a single DC ping
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DcPingResult {
|
||||||
|
pub dc_idx: usize,
|
||||||
|
pub dc_addr: SocketAddr,
|
||||||
|
pub rtt_ms: Option<f64>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of startup ping for one upstream (separate v6/v4 results)
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StartupPingResult {
|
||||||
|
pub v6_results: Vec<DcPingResult>,
|
||||||
|
pub v4_results: Vec<DcPingResult>,
|
||||||
|
pub upstream_name: String,
|
||||||
|
/// True if both IPv6 and IPv4 have at least one working DC
|
||||||
|
pub both_available: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Upstream Manager =============
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct UpstreamManager {
|
pub struct UpstreamManager {
|
||||||
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
||||||
@@ -30,12 +157,7 @@ impl UpstreamManager {
|
|||||||
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
|
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
|
||||||
let states = configs.into_iter()
|
let states = configs.into_iter()
|
||||||
.filter(|c| c.enabled)
|
.filter(|c| c.enabled)
|
||||||
.map(|c| UpstreamState {
|
.map(UpstreamState::new)
|
||||||
config: c,
|
|
||||||
healthy: true, // Optimistic start
|
|
||||||
fails: 0,
|
|
||||||
last_check: std::time::Instant::now(),
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
@@ -43,48 +165,64 @@ impl UpstreamManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Select an upstream using Weighted Round Robin (simplified)
|
/// Select upstream using latency-weighted random selection.
|
||||||
async fn select_upstream(&self) -> Option<usize> {
|
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
|
||||||
let upstreams = self.upstreams.read().await;
|
let upstreams = self.upstreams.read().await;
|
||||||
if upstreams.is_empty() {
|
if upstreams.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let healthy_indices: Vec<usize> = upstreams.iter()
|
let healthy: Vec<usize> = upstreams.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter(|(_, u)| u.healthy)
|
.filter(|(_, u)| u.healthy)
|
||||||
.map(|(i, _)| i)
|
.map(|(i, _)| i)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if healthy_indices.is_empty() {
|
if healthy.is_empty() {
|
||||||
// If all unhealthy, try any random one
|
return Some(rand::rng().gen_range(0..upstreams.len()));
|
||||||
return Some(rand::thread_rng().gen_range(0..upstreams.len()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Weighted selection
|
if healthy.len() == 1 {
|
||||||
let total_weight: u32 = healthy_indices.iter()
|
return Some(healthy[0]);
|
||||||
.map(|&i| upstreams[i].config.weight as u32)
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
if total_weight == 0 {
|
|
||||||
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut choice = rand::thread_rng().gen_range(0..total_weight);
|
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
|
||||||
|
let base = upstreams[i].config.weight as f64;
|
||||||
|
let latency_factor = upstreams[i].effective_latency(dc_idx)
|
||||||
|
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
|
||||||
|
.unwrap_or(1.0);
|
||||||
|
|
||||||
for &idx in &healthy_indices {
|
(i, base * latency_factor)
|
||||||
let weight = upstreams[idx].config.weight as u32;
|
}).collect();
|
||||||
|
|
||||||
|
let total: f64 = weights.iter().map(|(_, w)| w).sum();
|
||||||
|
|
||||||
|
if total <= 0.0 {
|
||||||
|
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut choice: f64 = rand::rng().gen_range(0.0..total);
|
||||||
|
|
||||||
|
for &(idx, weight) in &weights {
|
||||||
if choice < weight {
|
if choice < weight {
|
||||||
|
trace!(
|
||||||
|
upstream = idx,
|
||||||
|
dc = ?dc_idx,
|
||||||
|
weight = format!("{:.2}", weight),
|
||||||
|
total = format!("{:.2}", total),
|
||||||
|
"Upstream selected"
|
||||||
|
);
|
||||||
return Some(idx);
|
return Some(idx);
|
||||||
}
|
}
|
||||||
choice -= weight;
|
choice -= weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(healthy_indices[0])
|
Some(healthy[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> {
|
/// Connect to target through a selected upstream.
|
||||||
let idx = self.select_upstream().await
|
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
|
||||||
|
let idx = self.select_upstream(dc_idx).await
|
||||||
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
|
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
|
||||||
|
|
||||||
let upstream = {
|
let upstream = {
|
||||||
@@ -92,28 +230,33 @@ impl UpstreamManager {
|
|||||||
guard[idx].config.clone()
|
guard[idx].config.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
match self.connect_via_upstream(&upstream, target).await {
|
match self.connect_via_upstream(&upstream, target).await {
|
||||||
Ok(stream) => {
|
Ok(stream) => {
|
||||||
// Mark success
|
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||||
let mut guard = self.upstreams.write().await;
|
let mut guard = self.upstreams.write().await;
|
||||||
if let Some(u) = guard.get_mut(idx) {
|
if let Some(u) = guard.get_mut(idx) {
|
||||||
if !u.healthy {
|
if !u.healthy {
|
||||||
debug!("Upstream recovered: {:?}", u.config);
|
debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
|
||||||
}
|
}
|
||||||
u.healthy = true;
|
u.healthy = true;
|
||||||
u.fails = 0;
|
u.fails = 0;
|
||||||
|
|
||||||
|
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
|
||||||
|
u.dc_latency[di].update(rtt_ms);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Mark failure
|
|
||||||
let mut guard = self.upstreams.write().await;
|
let mut guard = self.upstreams.write().await;
|
||||||
if let Some(u) = guard.get_mut(idx) {
|
if let Some(u) = guard.get_mut(idx) {
|
||||||
u.fails += 1;
|
u.fails += 1;
|
||||||
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
|
warn!(fails = u.fails, "Upstream failed: {}", e);
|
||||||
if u.fails > 3 {
|
if u.fails > 3 {
|
||||||
u.healthy = false;
|
u.healthy = false;
|
||||||
warn!("Upstream disabled due to failures: {:?}", u.config);
|
warn!("Upstream marked unhealthy");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e)
|
Err(e)
|
||||||
@@ -129,18 +272,16 @@ impl UpstreamManager {
|
|||||||
|
|
||||||
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
||||||
|
|
||||||
// Non-blocking connect logic
|
|
||||||
socket.set_nonblocking(true)?;
|
socket.set_nonblocking(true)?;
|
||||||
match socket.connect(&target.into()) {
|
match socket.connect(&target.into()) {
|
||||||
Ok(()) => {},
|
Ok(()) => {},
|
||||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
Err(err) => return Err(ProxyError::Io(err)),
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let std_stream: std::net::TcpStream = socket.into();
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
let stream = TcpStream::from_std(std_stream)?;
|
let stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
// Wait for connection to complete
|
|
||||||
stream.writable().await?;
|
stream.writable().await?;
|
||||||
if let Some(e) = stream.take_error()? {
|
if let Some(e) = stream.take_error()? {
|
||||||
return Err(ProxyError::Io(e));
|
return Err(ProxyError::Io(e));
|
||||||
@@ -149,8 +290,6 @@ impl UpstreamManager {
|
|||||||
Ok(stream)
|
Ok(stream)
|
||||||
},
|
},
|
||||||
UpstreamType::Socks4 { address, interface, user_id } => {
|
UpstreamType::Socks4 { address, interface, user_id } => {
|
||||||
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
|
|
||||||
|
|
||||||
let proxy_addr: SocketAddr = address.parse()
|
let proxy_addr: SocketAddr = address.parse()
|
||||||
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||||
|
|
||||||
@@ -159,18 +298,16 @@ impl UpstreamManager {
|
|||||||
|
|
||||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||||
|
|
||||||
// Non-blocking connect logic
|
|
||||||
socket.set_nonblocking(true)?;
|
socket.set_nonblocking(true)?;
|
||||||
match socket.connect(&proxy_addr.into()) {
|
match socket.connect(&proxy_addr.into()) {
|
||||||
Ok(()) => {},
|
Ok(()) => {},
|
||||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
Err(err) => return Err(ProxyError::Io(err)),
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let std_stream: std::net::TcpStream = socket.into();
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
let mut stream = TcpStream::from_std(std_stream)?;
|
let mut stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
// Wait for connection to complete
|
|
||||||
stream.writable().await?;
|
stream.writable().await?;
|
||||||
if let Some(e) = stream.take_error()? {
|
if let Some(e) = stream.take_error()? {
|
||||||
return Err(ProxyError::Io(e));
|
return Err(ProxyError::Io(e));
|
||||||
@@ -180,8 +317,6 @@ impl UpstreamManager {
|
|||||||
Ok(stream)
|
Ok(stream)
|
||||||
},
|
},
|
||||||
UpstreamType::Socks5 { address, interface, username, password } => {
|
UpstreamType::Socks5 { address, interface, username, password } => {
|
||||||
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
|
|
||||||
|
|
||||||
let proxy_addr: SocketAddr = address.parse()
|
let proxy_addr: SocketAddr = address.parse()
|
||||||
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||||
|
|
||||||
@@ -190,18 +325,16 @@ impl UpstreamManager {
|
|||||||
|
|
||||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||||
|
|
||||||
// Non-blocking connect logic
|
|
||||||
socket.set_nonblocking(true)?;
|
socket.set_nonblocking(true)?;
|
||||||
match socket.connect(&proxy_addr.into()) {
|
match socket.connect(&proxy_addr.into()) {
|
||||||
Ok(()) => {},
|
Ok(()) => {},
|
||||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
Err(err) => return Err(ProxyError::Io(err)),
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let std_stream: std::net::TcpStream = socket.into();
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
let mut stream = TcpStream::from_std(std_stream)?;
|
let mut stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
// Wait for connection to complete
|
|
||||||
stream.writable().await?;
|
stream.writable().await?;
|
||||||
if let Some(e) = stream.take_error()? {
|
if let Some(e) = stream.take_error()? {
|
||||||
return Err(ProxyError::Io(e));
|
return Err(ProxyError::Io(e));
|
||||||
@@ -213,47 +346,282 @@ impl UpstreamManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Background task to check health
|
// ============= Startup Ping (test both IPv6 and IPv4) =============
|
||||||
pub async fn run_health_checks(&self) {
|
|
||||||
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2)
|
/// Ping all Telegram DCs through all upstreams.
|
||||||
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
|
/// Tests BOTH IPv6 and IPv4, returns separate results for each.
|
||||||
|
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
|
||||||
|
let upstreams: Vec<(usize, UpstreamConfig)> = {
|
||||||
|
let guard = self.upstreams.read().await;
|
||||||
|
guard.iter().enumerate()
|
||||||
|
.map(|(i, u)| (i, u.config.clone()))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut all_results = Vec::new();
|
||||||
|
|
||||||
|
for (upstream_idx, upstream_config) in &upstreams {
|
||||||
|
let upstream_name = match &upstream_config.upstream_type {
|
||||||
|
UpstreamType::Direct { interface } => {
|
||||||
|
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
|
||||||
|
}
|
||||||
|
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
|
||||||
|
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut v6_results = Vec::new();
|
||||||
|
let mut v4_results = Vec::new();
|
||||||
|
|
||||||
|
// === Ping IPv6 first ===
|
||||||
|
for dc_zero_idx in 0..NUM_DCS {
|
||||||
|
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
|
||||||
|
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
|
||||||
|
|
||||||
|
let result = tokio::time::timeout(
|
||||||
|
Duration::from_secs(DC_PING_TIMEOUT_SECS),
|
||||||
|
self.ping_single_dc(&upstream_config, addr_v6)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let ping_result = match result {
|
||||||
|
Ok(Ok(rtt_ms)) => {
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(*upstream_idx) {
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
}
|
||||||
|
DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v6,
|
||||||
|
rtt_ms: Some(rtt_ms),
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v6,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
Err(_) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v6,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some("timeout".to_string()),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
v6_results.push(ping_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Then ping IPv4 ===
|
||||||
|
for dc_zero_idx in 0..NUM_DCS {
|
||||||
|
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
|
||||||
|
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
|
||||||
|
|
||||||
|
let result = tokio::time::timeout(
|
||||||
|
Duration::from_secs(DC_PING_TIMEOUT_SECS),
|
||||||
|
self.ping_single_dc(&upstream_config, addr_v4)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let ping_result = match result {
|
||||||
|
Ok(Ok(rtt_ms)) => {
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(*upstream_idx) {
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
}
|
||||||
|
DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v4,
|
||||||
|
rtt_ms: Some(rtt_ms),
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v4,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
Err(_) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr: addr_v4,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some("timeout".to_string()),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
v4_results.push(ping_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if both IP versions have at least one working DC
|
||||||
|
let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some());
|
||||||
|
let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some());
|
||||||
|
let both_available = v6_has_working && v4_has_working;
|
||||||
|
|
||||||
|
// Update IP preference for each DC
|
||||||
|
{
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(*upstream_idx) {
|
||||||
|
for dc_zero_idx in 0..NUM_DCS {
|
||||||
|
let v6_ok = v6_results[dc_zero_idx].rtt_ms.is_some();
|
||||||
|
let v4_ok = v4_results[dc_zero_idx].rtt_ms.is_some();
|
||||||
|
|
||||||
|
u.dc_ip_pref[dc_zero_idx] = match (v6_ok, v4_ok) {
|
||||||
|
(true, true) => IpPreference::BothWork,
|
||||||
|
(true, false) => IpPreference::PreferV6,
|
||||||
|
(false, true) => IpPreference::PreferV4,
|
||||||
|
(false, false) => IpPreference::Unavailable,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_results.push(StartupPingResult {
|
||||||
|
v6_results,
|
||||||
|
v4_results,
|
||||||
|
upstream_name,
|
||||||
|
both_available,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
all_results
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let _stream = self.connect_via_upstream(config, target).await?;
|
||||||
|
Ok(start.elapsed().as_secs_f64() * 1000.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Health Checks =============
|
||||||
|
|
||||||
|
/// Background health check: rotates through DCs, 30s interval.
|
||||||
|
/// Uses preferred IP version based on config.
|
||||||
|
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
|
||||||
|
let mut dc_rotation = 0usize;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
|
||||||
|
let dc_zero_idx = dc_rotation % NUM_DCS;
|
||||||
|
dc_rotation += 1;
|
||||||
|
|
||||||
|
let dc_addr = if prefer_ipv6 {
|
||||||
|
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
|
||||||
|
} else {
|
||||||
|
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
|
||||||
|
};
|
||||||
|
|
||||||
|
let fallback_addr = if prefer_ipv6 {
|
||||||
|
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
|
||||||
|
} else {
|
||||||
|
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
|
||||||
|
};
|
||||||
|
|
||||||
let count = self.upstreams.read().await.len();
|
let count = self.upstreams.read().await.len();
|
||||||
|
|
||||||
for i in 0..count {
|
for i in 0..count {
|
||||||
let config = {
|
let config = {
|
||||||
let guard = self.upstreams.read().await;
|
let guard = self.upstreams.read().await;
|
||||||
guard[i].config.clone()
|
guard[i].config.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
let result = tokio::time::timeout(
|
let result = tokio::time::timeout(
|
||||||
Duration::from_secs(10),
|
Duration::from_secs(10),
|
||||||
self.connect_via_upstream(&config, check_target)
|
self.connect_via_upstream(&config, dc_addr)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(_stream)) => {
|
||||||
|
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
let u = &mut guard[i];
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
|
||||||
|
if !u.healthy {
|
||||||
|
info!(
|
||||||
|
rtt = format!("{:.0} ms", rtt_ms),
|
||||||
|
dc = dc_zero_idx + 1,
|
||||||
|
"Upstream recovered"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
u.healthy = true;
|
||||||
|
u.fails = 0;
|
||||||
|
u.last_check = std::time::Instant::now();
|
||||||
|
}
|
||||||
|
Ok(Err(_)) | Err(_) => {
|
||||||
|
// Try fallback
|
||||||
|
debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
|
||||||
|
|
||||||
|
let start2 = Instant::now();
|
||||||
|
let result2 = tokio::time::timeout(
|
||||||
|
Duration::from_secs(10),
|
||||||
|
self.connect_via_upstream(&config, fallback_addr)
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
let mut guard = self.upstreams.write().await;
|
let mut guard = self.upstreams.write().await;
|
||||||
let u = &mut guard[i];
|
let u = &mut guard[i];
|
||||||
|
|
||||||
match result {
|
match result2 {
|
||||||
Ok(Ok(_stream)) => {
|
Ok(Ok(_stream)) => {
|
||||||
|
let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
|
||||||
if !u.healthy {
|
if !u.healthy {
|
||||||
debug!("Upstream recovered: {:?}", u.config);
|
info!(
|
||||||
|
rtt = format!("{:.0} ms", rtt_ms),
|
||||||
|
dc = dc_zero_idx + 1,
|
||||||
|
"Upstream recovered (fallback)"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
u.healthy = true;
|
u.healthy = true;
|
||||||
u.fails = 0;
|
u.fails = 0;
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!("Health check failed for {:?}: {}", u.config, e);
|
u.fails += 1;
|
||||||
// Don't mark unhealthy immediately in background check
|
debug!(dc = dc_zero_idx + 1, fails = u.fails,
|
||||||
|
"Health check failed (both): {}", e);
|
||||||
|
if u.fails > 3 {
|
||||||
|
u.healthy = false;
|
||||||
|
warn!("Upstream unhealthy (fails)");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Health check timeout for {:?}", u.config);
|
u.fails += 1;
|
||||||
|
debug!(dc = dc_zero_idx + 1, fails = u.fails,
|
||||||
|
"Health check timeout (both)");
|
||||||
|
if u.fails > 3 {
|
||||||
|
u.healthy = false;
|
||||||
|
warn!("Upstream unhealthy (timeout)");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
u.last_check = std::time::Instant::now();
|
u.last_check = std::time::Instant::now();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the preferred IP for a DC (for use by other components)
|
||||||
|
pub async fn get_dc_ip_preference(&self, dc_idx: i16) -> Option<IpPreference> {
|
||||||
|
let guard = self.upstreams.read().await;
|
||||||
|
if guard.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
UpstreamState::dc_array_idx(dc_idx)
|
||||||
|
.map(|idx| guard[0].dc_ip_pref[idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get preferred DC address based on config preference
|
||||||
|
pub async fn get_dc_addr(&self, dc_idx: i16, prefer_ipv6: bool) -> Option<SocketAddr> {
|
||||||
|
let arr_idx = UpstreamState::dc_array_idx(dc_idx)?;
|
||||||
|
|
||||||
|
let ip = if prefer_ipv6 {
|
||||||
|
TG_DATACENTERS_V6[arr_idx]
|
||||||
|
} else {
|
||||||
|
TG_DATACENTERS_V4[arr_idx]
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user