99 Commits

Author SHA1 Message Date
Alexey
a688bfe22f New Relay on Tokio Copy Bidirectional 2026-02-12 20:20:01 +03:00
Alexey
9bd12f6acb 1.2.0.2 Special DC support: merge pull request #32 from telemt/1.2.0.2
1.2.0.2 Special DC support
2026-02-12 18:46:40 +03:00
Alexey
61581203c4 Semaphore + Async Magics for Defcluster
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:38:05 +03:00
Alexey
84668e671e Default Cluster Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:25:41 +03:00
Alexey
5bde202866 Startup logging refactoring: merge pull request #26 from Katze-942/main
Startup logging refactoring
2026-02-12 11:46:22 +03:00
Жора Змейкин
9304d5256a Refactor startup logging
Move all startup output (DC pings, proxy links) from println!() to
      info!() for consistent tracing format. Add reload::Layer so startup
      messages stay visible even in silent mode.
2026-02-12 05:14:23 +03:00
Alexey
364bc6e278 Merge pull request #21 from telemt/1.2.0.0
1.2.0.0
2026-02-11 17:00:46 +03:00
Alexey
e83db704b7 Pull-up 2026-02-11 16:55:18 +03:00
Alexey
acf90043eb Merge pull request #15 from telemt/main-emergency
Update README.md
2026-02-11 00:56:12 +03:00
Alexey
0011e20653 Update README.md 2026-02-11 00:55:27 +03:00
Alexey
41fb307858 Merge pull request #14 from telemt/main-emergency
Update README.md
2026-02-11 00:41:30 +03:00
Alexey
6a78c44d2e Update README.md 2026-02-11 00:41:08 +03:00
Alexey
be9c9858ac Merge pull request #13 from telemt/main-emergency
Main emergency
2026-02-11 00:39:45 +03:00
Alexey
2fa8d85b4c Update README.md 2026-02-11 00:31:45 +03:00
Alexey
310666fd44 Update README.md 2026-02-11 00:31:02 +03:00
Alexey
6cafee153a Fire-and-Forgot™ Draft
- Added fire-and-forget ignition via `--init` CLI command:
  - New `mod cli;` module handling installation logic
  - Extended `parse_cli()` to process `--init` flag (runs synchronously before tokio runtime)
  - Expanded `--help` output with installation options

- `--init` command functionality:
  - Generates random secret if not provided via `--secret`
  - Creates `/etc/telemt/config.toml` from template with user-provided or default parameters (`--port`, `--domain`, `--user`, `--config-dir`)
  - Creates hardened systemd unit `/etc/systemd/system/telemt.service` with security features:
    - `NoNewPrivileges=true`
    - `ProtectSystem=strict`
    - `PrivateTmp=true`
  - Runs `systemctl enable --now telemt.service`
  - Outputs `tg://` proxy links for the running service

- Implementation approach:
  - `--init` handled at the very start of `main()` before any async context
  - Uses blocking operations throughout (file I/O, `std::process::Command` for systemctl)
  - IP detection for tg:// links performed via blocking HTTP request
  - Command exits after installation without entering normal proxy runtime

- New CLI parameters for installation:
  - `--port` - listening port (default: 443)
  - `--domain` - TLS domain (default: auto-detected)
  - `--secret` - custom secret (default: randomly generated)
  - `--user` - systemd service user (default: telemt)
  - `--config-dir` - configuration directory (default: /etc/telemt)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:31:49 +03:00
Alexey
32f60f34db Fix Stats + UpstreamState + EMA Latency Tracking
- Per-DC latency tracking in UpstreamState (array of 5 EMA instances, one per DC):
  - Added `dc_latency: [LatencyEma; 5]` – per‑DC tracking instead of a single global EMA
  - `effective_latency(dc_idx)` – returns DC‑specific latency, falls back to average if unavailable
  - `select_upstream(dc_idx)` – now performs latency‑weighted selection: effective_weight = config_weight × (1000 / latency_ms)
    - Example: two upstreams with equal config weight but latencies of 50ms and 200ms → selection probabilities become 80% / 20%
  - `connect(target, dc_idx)` – extended signature, dc_idx used for upstream selection and per‑DC RTT recording
  - All ping/health‑check operations now record RTT into `dc_latency[dc_zero_index]`
  - `upstream_manager.connect(dc_addr)` changed to `upstream_manager.connect(dc_addr, Some(success.dc_idx))` – DC index now participates in upstream selection and per‑DC RTT logging
  - `client.rs` – passes dc_idx when connecting to Telegram

- Summary: Upstream selection now accounts for per‑DC latency using the formula weight × (1000/ms). With multiple upstreams (e.g., direct + socks5), traffic automatically flows to the faster route for each specific DC. With a single upstream, the data is used for monitoring without affecting routing.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:24:12 +03:00
Alexey
158eae8d2a Antireplay Improvements + DC Ping
- Fix: LruCache::get type ambiguity in stats/mod.rs
  - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already &[u8], resolved via Box<[u8]>: Borrow<[u8]>)
  - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit &[u8] instead of &Box<[u8]>)

- Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval):
  - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT
  - `connect()` – measures RTT of each real connection and updates EMA
  - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error
  - `run_health_checks(prefer_ipv6)` – accepts IPv6 preference parameter, rotates DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced to 30s from 60s, failed checks now mark upstream as unhealthy after 3 consecutive fails
  - `DcPingResult` / `StartupPingResult` – public structures for display
  - DC Ping at startup: calls `upstream_manager.ping_all_dcs()` before accept loop, outputs table via `println!` (always visible)
  - Health checks with `prefer_ipv6`: `run_health_checks(prefer_ipv6)` receives the parameter
  - Exported `StartupPingResult` and `DcPingResult`

- Summary: Startup DC ping with RTT, rotational health-check with EMA latency tracking, 30-second interval, correct unhealthy marking after 3 fails.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:18:25 +03:00
Alexey
92cedabc81 Zeroize for key + log refactor + fix tests
- Fixed tests that failed to compile due to mismatched generic parameters of HandshakeResult:
  - Changed `HandshakeResult<i32>` to `HandshakeResult<i32, (), ()>`
  - Changed `HandshakeResult::BadClient` to `HandshakeResult::BadClient { reader: (), writer: () }`

- Added Zeroize for all structures holding key material:
  - AesCbc – key and IV are zeroized on drop
  - SecureRandomInner – PRNG output buffer is zeroized on drop; local key copy in constructor is zeroized immediately after being passed to the cipher
  - ObfuscationParams – all four key‑material fields are zeroized on drop
  - HandshakeSuccess – all four key‑material fields are zeroized on drop

- Added protocol‑requirement documentation for legacy hashes (CodeQL suppression) in hash.rs (MD5/SHA‑1)

- Added documentation for zeroize limitations of AesCtr (opaque cipher state) in aes.rs

- Implemented silent‑mode logging and refactored initialization:
  - Added LogLevel enum to config and CLI flags --silent / --log-level
  - Added parse_cli() to handle --silent, --log-level, --help
  - Restructured main.rs initialization order: CLI → config load → determine log level → init tracing
  - Errors before tracing initialization are printed via eprintln!
  - Proxy links (tg://) are printed via println! – always visible regardless of log level
  - Configuration summary and operational messages are logged via info! (suppressed in silent mode)
  - Connection processing errors are lowered to debug! (hidden in silent mode)
  - Warning about default tls_domain moved to main (after tracing init)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 19:49:41 +03:00
Alexey
b9428d9780 Antireplay on sliding window + SecureRandom 2026-02-07 18:26:44 +03:00
Alexey
5876f0c4d5 Update rust.yml 2026-02-07 17:58:10 +03:00
Alexey
94750a2749 Update README.md 2026-01-22 03:33:13 +03:00
Alexey
cf4b240913 Update README.md 2026-01-22 03:26:34 +03:00
Alexey
1424fbb1d5 Update README.md 2026-01-22 03:19:50 +03:00
Alexey
97f4c0d3b7 Update README.md 2026-01-22 03:17:37 +03:00
Alexey
806536fab6 Update README.md 2026-01-22 03:14:39 +03:00
Alexey
df8cfe462b Update README.md 2026-01-22 03:13:08 +03:00
Alexey
a5f1521d71 Update README.md 2026-01-22 03:07:38 +03:00
Alexey
8de7b7adc0 Update README.md 2026-01-22 03:03:19 +03:00
Alexey
cde1b15ef0 Update config.toml 2026-01-22 02:45:30 +03:00
Alexey
46e4c06ba6 Update README.md 2026-01-22 01:59:18 +03:00
Alexey
b7673daf0f Update README.md 2026-01-22 01:57:44 +03:00
Alexey
397ed8f193 Update README.md 2026-01-22 01:56:42 +03:00
Alexey
d90b2fd300 Update README.md 2026-01-22 01:55:31 +03:00
Alexey
d62136d9fa Update README.md 2026-01-22 01:53:05 +03:00
Alexey
0f8933b908 Update README.md 2026-01-22 01:48:37 +03:00
Alexey
0ec87974d1 Update README.md 2026-01-22 01:47:43 +03:00
Alexey
c8446c32d1 Update README.md 2026-01-22 01:46:28 +03:00
Alexey
f79a2eb097 Update README.md 2026-01-22 01:26:36 +03:00
Alexey
dea1a3b5de Update README.md 2026-01-22 01:16:46 +03:00
Alexey
97ce235ae4 Update README.md 2026-01-22 01:16:35 +03:00
Alexey
d04757eb9c Update README.md 2026-01-20 11:13:33 +03:00
Alexey
2d7901a978 Update README.md 2026-01-20 11:09:24 +03:00
Alexey
3881ba9bed 1.1.1.0 2026-01-20 02:09:56 +03:00
Alexey
5ac9089ccb Update README.md 2026-01-20 01:39:59 +03:00
Alexey
eb8b991818 Update README.md 2026-01-20 01:32:39 +03:00
Alexey
2ce8fbb2cc 1.1.0.0 2026-01-20 01:20:02 +03:00
Alexey
038f0cd5d1 Update README.md 2026-01-19 23:52:31 +03:00
Alexey
efea3f981d Update README.md 2026-01-19 23:51:43 +03:00
Alexey
42ce9dd671 Update README.md 2026-01-12 22:11:21 +03:00
Alexey
4fa6867056 Merge pull request #7 from telemt/1.0.3.0
1.0.3.0
2026-01-12 00:49:31 +03:00
Alexey
54ea6efdd0 Global rewrite of AES-CTR + Upstream Pending + to_accept selection 2026-01-12 00:46:51 +03:00
brekotis
27ac32a901 Fixes in TLS for iOS 2026-01-12 00:32:42 +03:00
Alexey
829f53c123 Fixes for iOS 2026-01-11 22:59:51 +03:00
Alexey
43eae6127d Update README.md 2026-01-10 22:17:03 +03:00
Alexey
a03212c8cc Update README.md 2026-01-10 22:15:02 +03:00
Alexey
2613969a7c Update rust.yml 2026-01-09 23:15:52 +03:00
Alexey
be1b2db867 Update README.md 2026-01-08 02:10:34 +03:00
Alexey
8fbee8701b Update README.md 2026-01-08 02:10:02 +03:00
Alexey
952d160870 Update README.md 2026-01-08 02:03:30 +03:00
Alexey
91ae6becde Update README.md 2026-01-08 02:01:50 +03:00
Alexey
e1f576e4fe Update README.md 2026-01-08 02:00:27 +03:00
Alexey
a7556cabdc Update README.md 2026-01-07 19:12:16 +03:00
Alexey
b2e8d16bb1 Update README.md 2026-01-07 19:10:04 +03:00
Alexey
d95e762812 Update README.md 2026-01-07 19:07:08 +03:00
Alexey
384f927fc3 Update README.md 2026-01-07 19:06:28 +03:00
Alexey
1b7c09ae18 Update README.md 2026-01-07 18:54:44 +03:00
Alexey
85cb4092d5 1.0.2.0 2026-01-07 18:16:01 +03:00
Alexey
5016160ac3 1.0.1.2 2026-01-07 17:42:30 +03:00
Alexey
4f007f3128 1.0.1.1
Drafting Upstreams and SOCKS
2026-01-07 17:22:10 +03:00
Alexey
7746a1177c Update README.md 2026-01-06 15:10:14 +03:00
Alexey
2bb2a2983f Update README.md 2026-01-06 14:57:52 +03:00
Alexey
5778be4f6e Update README.md 2026-01-02 19:10:12 +03:00
Alexey
f443d3dfc7 Update README.md 2026-01-02 16:54:35 +03:00
Alexey
450cf180ad Update README.md 2026-01-02 16:33:42 +03:00
Alexey
84fa7face0 Update README.md 2026-01-02 16:33:07 +03:00
Alexey
f8a2ea1972 Update README.md 2026-01-02 16:31:55 +03:00
Alexey
96d0a6bdfa Update README.md 2026-01-02 16:31:29 +03:00
Alexey
eeee55e8ea Update README.md 2026-01-02 16:21:52 +03:00
Alexey
7be179b3c0 Added accurate MTProto Frame Types + Tokio Async Intergr 2026-01-02 01:37:02 +03:00
Alexey
b2e034f8f1 Deleting of inconsiderately added rs 2026-01-02 01:20:26 +03:00
Alexey
ffe5a6cfb7 Fake TLS Fixes for Async IO
added more comments and schemas
2026-01-02 01:17:56 +03:00
Alexey
0e096ca8fb TLS Stream Tuning 2026-01-01 23:48:52 +03:00
Alexey
50658525cf Merge branch 'main' of https://github.com/telemt/telemt 2026-01-01 23:34:13 +03:00
Alexey
4fd5ff4e83 ET + SM + Crypto Fixes 2026-01-01 23:34:04 +03:00
Alexey
df4f312fec Update rust.yml 2025-12-31 06:04:56 +03:00
Alexey
7d9a8b99b4 Update rust.yml 2025-12-31 06:01:59 +03:00
Alexey
06f34e55cd Update rust.yml 2025-12-31 05:59:20 +03:00
Alexey
153cb7f3a3 Create rust.yml 2025-12-31 05:54:45 +03:00
Alexey
7f8904a989 Update README.md 2025-12-31 05:48:17 +03:00
Alexey
0ee71a59a0 Update README.md 2025-12-31 05:44:48 +03:00
Alexey
45c7347e22 Update README.md 2025-12-31 05:29:09 +03:00
Alexey
3805237d74 Update README.md 2025-12-31 05:28:32 +03:00
Alexey
5b281bf7fd Create telemt.service
based Systemd service
2025-12-31 05:10:18 +03:00
Alexey
d64cccd52c Update README.md 2025-12-31 04:45:28 +03:00
Alexey
016fdada68 Update README.md 2025-12-31 04:39:49 +03:00
Alexey
2c2ceeaf54 Update README.md 2025-12-30 22:18:22 +03:00
Alexey
dd6badd786 Update README.md 2025-12-30 21:31:54 +03:00
Alexey
50e72368c8 Update README.md 2025-12-30 21:29:04 +03:00
35 changed files with 10707 additions and 1598 deletions

46
.github/workflows/rust.yml vendored Normal file
View File

@@ -0,0 +1,46 @@
name: Rust
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
- name: Check for unused dependencies
run: cargo udeps || true

2742
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,14 @@
[package] [package]
name = "telemt" name = "telemt"
version = "1.0.0" version = "1.2.0"
edition = "2021" edition = "2024"
rust-version = "1.75"
[dependencies] [dependencies]
# C # C
libc = "0.2" libc = "0.2"
# Async runtime # Async runtime
tokio = { version = "1.35", features = ["full", "tracing"] } tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
# Crypto # Crypto
@@ -20,40 +19,41 @@ sha2 = "0.10"
sha1 = "0.10" sha1 = "0.10"
md-5 = "0.10" md-5 = "0.10"
hmac = "0.12" hmac = "0.12"
crc32fast = "1.3" crc32fast = "1.4"
zeroize = { version = "1.8", features = ["derive"] }
# Network # Network
socket2 = { version = "0.5", features = ["all"] } socket2 = { version = "0.5", features = ["all"] }
rustls = "0.22"
# Serial # Serialization
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.8" toml = "0.8"
# Utils # Utils
bytes = "1.5" bytes = "1.9"
thiserror = "1.0" thiserror = "2.0"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
parking_lot = "0.12" parking_lot = "0.12"
dashmap = "5.5" dashmap = "5.5"
lru = "0.12" lru = "0.12"
rand = "0.8" rand = "0.9"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
hex = "0.4" hex = "0.4"
base64 = "0.21" base64 = "0.22"
url = "2.5" url = "2.5"
regex = "1.10" regex = "1.11"
once_cell = "1.19" crossbeam-queue = "0.3"
# HTTP # HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
[dev-dependencies] [dev-dependencies]
tokio-test = "0.4" tokio-test = "0.4"
criterion = "0.5" criterion = "0.5"
proptest = "1.4" proptest = "1.4"
futures = "0.3"
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"

425
README.md
View File

@@ -1,2 +1,423 @@
# telemt # Telemt - MTProxy on Rust + Tokio
MTProxy for Telegram on Rust + Tokio
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
## Emergency
**Важное сообщение для пользователей из России**
Мы работаем над проектом с Нового года и сейчас готовим новый релиз - 1.2
В нём имплементируется поддержка Middle Proxy Protocol - основного терминатора для Ad Tag:
работа над ним идёт с 6 ферваля, а уже 10 февраля произошли "громкие события"...
Если у вас есть компетенции в асинхронных сетевых приложениях - мы открыты к предложениям и pull requests
**Important message for users from Russia**
We've been working on the project since December 30 and are currently preparing a new release 1.2
It implements support for the Middle Proxy Protocol the primary point for the Ad Tag:
development on it started on February 6th, and by February 10th, "big activity" in Russia had already "taken place"...
If you have expertise in asynchronous network applications we are open to ideas and pull requests!
# Features
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
# GOTO
- [Features](#features)
- [Quick Start Guide](#quick-start-guide)
- [How to use?](#how-to-use)
- [Systemd Method](#telemt-via-systemd)
- [Configuration](#configuration)
- [Minimal Configuration](#minimal-configuration-for-first-start)
- [Advanced](#advanced)
- [Adtag](#adtag)
- [Listening and Announce IPs](#listening-and-announce-ips)
- [Upstream Manager](#upstream-manager)
- [IP](#bind-on-ip)
- [SOCKS](#socks45-as-upstream)
- [FAQ](#faq)
- [Recognizability for DPI + crawler](#recognizability-for-dpi-and-crawler)
- [Telegram Calls](#telegram-calls-via-mtproxy)
- [DPI](#how-does-dpi-see-mtproxy-tls)
- [Whitelist on Network Level](#whitelist-on-ip)
- [Build](#build)
- [Why Rust?](#why-rust)
## Features
- Full support for all official MTProto proxy modes:
- Classic
- Secure - with `dd` prefix
- Fake TLS - with `ee` prefix + SNI fronting
- Replay attack protection
- Optional traffic masking: forward unrecognized connections to a real web server, e.g. GitHub 🤪
- Configurable keepalives + timeouts + IPv6 and "Fast Mode"
- Graceful shutdown on Ctrl+C
- Extensive logging via `trace` and `debug` with `RUST_LOG` method
## Quick Start Guide
**This software is designed for Debian-based OS: in addition to Debian, these are Ubuntu, Mint, Kali, MX and many other Linux**
1. Download release
```bash
wget https://github.com/telemt/telemt/releases/latest/download/telemt
```
2. Move to Bin Folder
```bash
mv telemt /bin
```
4. Make Executable
```bash
chmod +x /bin/telemt
```
5. Go to [How to use?](#how-to-use) section for for further steps
## How to use?
### Telemt via Systemd
**This instruction "assume" that you:**
- logged in as root or executed `su -` / `sudo su`
- you already have an assembled and executable `telemt` in /bin folder as a result of the [Quick Start Guide](#quick-start-guide) or [Build](#build)
**0. Check port and generate secrets**
The port you have selected for use should be MISSING from the list, when:
```bash
netstat -lnp
```
Generate 16 bytes/32 characters HEX with OpenSSL or another way:
```bash
openssl rand -hex 16
```
OR
```bash
xxd -l 16 -p /dev/urandom
```
OR
```bash
python3 -c 'import os; print(os.urandom(16).hex())'
```
**1. Place your config to /etc/telemt.toml**
Open nano
```bash
nano /etc/telemt.toml
```
paste your config from [Configuration](#configuration) section
then Ctrl+X -> Y -> Enter to save
**2. Create service on /etc/systemd/system/telemt.service**
Open nano
```bash
nano /etc/systemd/system/telemt.service
```
paste this Systemd Module
```bash
[Unit]
Description=Telemt
After=network.target
[Service]
Type=simple
WorkingDirectory=/bin
ExecStart=/bin/telemt /etc/telemt.toml
Restart=on-failure
[Install]
WantedBy=multi-user.target
```
then Ctrl+X -> Y -> Enter to save
**3.** In Shell type `systemctl start telemt` - it must start with zero exit-code
**4.** In Shell type `systemctl status telemt` - there you can reach info about current MTProxy status
**5.** In Shell type `systemctl enable telemt` - then telemt will start with system startup, after the network is up
## Configuration
### Minimal Configuration for First Start
```toml
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
# === General Settings ===
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[general.modes]
classic = false
secure = false
tls = true
# === Server Binding ===
[server]
port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# username "hello" is used for example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
```
### Advanced
#### Adtag
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[General]`
```toml
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
```
#### Listening and Announce IPs
To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml:
```toml
[[server.listeners]]
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
announce_ip = "1.2.3.4" # IP in links; comment with # if not used
```
#### Upstream Manager
To specify upstream, add to section `[[upstreams]]` of config.toml:
##### Bind on IP
```toml
[[upstreams]]
type = "direct"
weight = 1
enabled = true
interface = "192.168.1.100" # Change to your outgoing IP
```
##### SOCKS4/5 as Upstream
- Without Auth:
```toml
[[upstreams]]
type = "socks5" # Specify SOCKS4 or SOCKS5
address = "1.2.3.4:1234" # SOCKS-server Address
weight = 1 # Set Weight for Scenarios
enabled = true
```
- With Auth:
```toml
[[upstreams]]
type = "socks5" # Specify SOCKS4 or SOCKS5
address = "1.2.3.4:1234" # SOCKS-server Address
username = "user" # Username for Auth on SOCKS-server
password = "pass" # Password for Auth on SOCKS-server
weight = 1 # Set Weight for Scenarios
enabled = true
```
## FAQ
### Recognizability for DPI and crawler
Since version 1.1.0.0, we have debugged masking perfectly: for all clients without "presenting" a key,
we transparently direct traffic to the target host!
- We consider this a breakthrough aspect, which has no stable analogues today
- Based on this: if `telemt` configured correctly, **TLS mode is completely identical to real-life handshake + communication** with a specified host
- Here is our evidence:
- 212.220.88.77 - "dummy" host, running `telemt`
- `petrovich.ru` - `tls` + `masking` host, in HEX: `706574726f766963682e7275`
- **No MITM + No Fake Certificates/Crypto** = pure transparent *TCP Splice* to "best" upstream: MTProxy or tls/mask-host:
- DPI see legitimate HTTPS to `tls_host`, including *valid chain-of-trust* and entropy
- Crawlers completely satisfied receiving responses from `mask_host`
#### Client WITH secret-key accesses the MTProxy resource:
<img width="360" height="439" alt="telemt" src="https://github.com/user-attachments/assets/39352afb-4a11-4ecc-9d91-9e8cfb20607d" />
#### Client WITHOUT secret-key gets transparent access to the specified resource:
- with trusted certificate
- with original handshake
- with full request-response way
- with low-latency overhead
```bash
root@debian:~/telemt# curl -v -I --resolve petrovich.ru:443:212.220.88.77 https://petrovich.ru/
* Added petrovich.ru:443:212.220.88.77 to DNS cache
* Hostname petrovich.ru was found in DNS cache
* Trying 212.220.88.77:443...
* Connected to petrovich.ru (212.220.88.77) port 443 (#0)
* ALPN: offers h2,http/1.1
* TLSv1.3 (OUT), TLS handshake, Client hello (1):
* CAfile: /etc/ssl/certs/ca-certificates.crt
* CApath: /etc/ssl/certs
* TLSv1.3 (IN), TLS handshake, Server hello (2):
* TLSv1.3 (IN), TLS handshake, Encrypted Extensions (8):
* TLSv1.3 (IN), TLS handshake, Certificate (11):
* TLSv1.3 (IN), TLS handshake, CERT verify (15):
* TLSv1.3 (IN), TLS handshake, Finished (20):
* TLSv1.3 (OUT), TLS change cipher, Change cipher spec (1):
* TLSv1.3 (OUT), TLS handshake, Finished (20):
* SSL connection using TLSv1.3 / TLS_AES_256_GCM_SHA384
* ALPN: server did not agree on a protocol. Uses default.
* Server certificate:
* subject: C=RU; ST=Saint Petersburg; L=Saint Petersburg; O=STD Petrovich; CN=*.petrovich.ru
* start date: Jan 28 11:21:01 2025 GMT
* expire date: Mar 1 11:21:00 2026 GMT
* subjectAltName: host "petrovich.ru" matched cert's "petrovich.ru"
* issuer: C=BE; O=GlobalSign nv-sa; CN=GlobalSign RSA OV SSL CA 2018
* SSL certificate verify ok.
* using HTTP/1.x
> HEAD / HTTP/1.1
> Host: petrovich.ru
> User-Agent: curl/7.88.1
> Accept: */*
>
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* old SSL session ID is stale, removing
< HTTP/1.1 200 OK
HTTP/1.1 200 OK
< Server: Variti/0.9.3a
Server: Variti/0.9.3a
< Date: Thu, 01 Jan 2026 00:0000 GMT
Date: Thu, 01 Jan 2026 00:0000 GMT
< Access-Control-Allow-Origin: *
Access-Control-Allow-Origin: *
< Content-Type: text/html
Content-Type: text/html
< Cache-Control: no-store
Cache-Control: no-store
< Expires: Thu, 01 Jan 2026 00:0000 GMT
Expires: Thu, 01 Jan 2026 00:0000 GMT
< Pragma: no-cache
Pragma: no-cache
< Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
< Content-Type: text/html
Content-Type: text/html
< Content-Length: 31253
Content-Length: 31253
< Connection: keep-alive
Connection: keep-alive
< Keep-Alive: timeout=60
Keep-Alive: timeout=60
<
* Connection #0 to host petrovich.ru left intact
```
- We challenged ourselves, we kept trying and we didn't only *beat the air*: now, we have something to show you
- Do not just take our word for it? - This is great and we respect that: you can build your own `telemt` or download a build and check it right now
### Telegram Calls via MTProxy
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
### How does DPI see MTProxy TLS?
- DPI sees MTProxy in Fake TLS (ee) mode as TLS 1.3
- the SNI you specify sends both the client and the server;
- ALPN is similar to HTTP 1.1/2;
- high entropy, which is normal for AES-encrypted traffic;
### Whitelist on IP
- MTProxy cannot work when there is:
- no IP connectivity to the target host: Russian Whitelist on Mobile Networks - "Белый список"
- OR all TCP traffic is blocked
- OR high entropy/encrypted traffic is blocked: content filters at universities and critical infrastructure
- OR all TLS traffic is blocked
- OR specified port is blocked: use 443 to make it "like real"
- OR provided SNI is blocked: use "officially approved"/innocuous name
- like most protocols on the Internet;
- these situations are observed:
- in China behind the Great Firewall
- in Russia on mobile networks, less in wired networks
- in Iran during "activity"
## Build
```bash
# Cloning repo
git clone https://github.com/telemt/telemt
# Changing Directory to telemt
cd telemt
# Starting Release Build
cargo build --release
# Move to /bin
mv ./target/release/telemt /bin
# Make executable
chmod +x /bin/telemt
# Lets go!
telemt config.toml
```
## Why Rust?
- Long-running reliability and idempotent behavior
- Rusts deterministic resource management - RAII
- No garbage collector
- Memory safety and reduced attack surface
- Tokio's asynchronous architecture
## Issues
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
- ✅ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
## Roadmap
- Public IP in links
- Config Reload-on-fly
- Bind to device or IP for outbound/inbound connections
- Adtag Support per SNI / Secret
- Fail-fast on start + Fail-soft on runtime (only WARN/ERROR)
- Zero-copy, minimal allocs on hotpath
- DC Healthchecks + global fallback
- No global mutable state
- Client isolation + Fair Bandwidth
- Backpressure-aware IO
- "Secret Policy" - SNI / Secret Routing :D
- Multi-upstream Balancer and Failover
- Strict FSM per handshake
- Session-based Antireplay with Sliding window, non-broking reconnects
- Web Control: statistic, state of health, latency, client experience...

View File

@@ -1,13 +1,79 @@
port = 443 # === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
[users] # === General Settings ===
user1 = "00000000000000000000000000000000" [general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = true
ad_tag = "00000000000000000000000000000000"
[modes] # Log level: debug | verbose | normal | silent
classic = true # Can be overridden with --silent or --log-level CLI flags
secure = true # RUST_LOG env var takes absolute priority over all of these
log_level = "normal"
[general.modes]
classic = false
secure = false
tls = true tls = true
tls_domain = "www.github.com" # === Server Binding ===
fast_mode = true [server]
prefer_ipv6 = false port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "google.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1

300
src/cli.rs Normal file
View File

@@ -0,0 +1,300 @@
//! CLI commands: --init (fire-and-forget setup)
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rand::Rng;
/// Options for the init command
pub struct InitOptions {
pub port: u16,
pub domain: String,
pub secret: Option<String>,
pub username: String,
pub config_dir: PathBuf,
pub no_start: bool,
}
impl Default for InitOptions {
fn default() -> Self {
Self {
port: 443,
domain: "www.google.com".to_string(),
secret: None,
username: "user".to_string(),
config_dir: PathBuf::from("/etc/telemt"),
no_start: false,
}
}
}
/// Parse --init subcommand options from CLI args.
///
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
if !args.iter().any(|a| a == "--init") {
return None;
}
let mut opts = InitOptions::default();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--port" => {
i += 1;
if i < args.len() {
opts.port = args[i].parse().unwrap_or(443);
}
}
"--domain" => {
i += 1;
if i < args.len() {
opts.domain = args[i].clone();
}
}
"--secret" => {
i += 1;
if i < args.len() {
opts.secret = Some(args[i].clone());
}
}
"--user" => {
i += 1;
if i < args.len() {
opts.username = args[i].clone();
}
}
"--config-dir" => {
i += 1;
if i < args.len() {
opts.config_dir = PathBuf::from(&args[i]);
}
}
"--no-start" => {
opts.no_start = true;
}
_ => {}
}
i += 1;
}
Some(opts)
}
/// Run the fire-and-forget setup.
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[telemt] Fire-and-forget setup");
eprintln!();
// 1. Generate or validate secret
let secret = match opts.secret {
Some(s) => {
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
eprintln!("[error] Secret must be exactly 32 hex characters");
std::process::exit(1);
}
s
}
None => generate_secret(),
};
eprintln!("[+] Secret: {}", secret);
eprintln!("[+] User: {}", opts.username);
eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory
fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml");
// 3. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit
let exe_path = std::env::current_exe()
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path);
match fs::write(unit_path, &unit_content) {
Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display());
}
Err(e) => {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content);
// Still print links and config
print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(());
}
}
// 5. Reload systemd
run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start)
if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started");
// Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl")
.args(["is-active", "telemt.service"])
.output();
match status {
Ok(out) if out.status.success() => {
eprintln!("[+] Service is running");
}
_ => {
eprintln!("[!] Service may not have started correctly");
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
}
}
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service");
}
eprintln!();
// 8. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(())
}
fn generate_secret() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
hex::encode(bytes)
}
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!(
r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate
show_link = ["{username}"]
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"
[general.modes]
classic = false
secure = false
tls = true
[server]
port = {port}
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
[[server.listeners]]
ip = "0.0.0.0"
[[server.listeners]]
ip = "::"
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
[censorship]
tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
{username} = "{secret}"
[[upstreams]]
type = "direct"
enabled = true
weight = 10
"#,
username = username,
secret = secret,
port = port,
domain = domain,
)
}
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!(
r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/nicepkg/telemt
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true
[Install]
WantedBy=multi-user.target
"#,
exe = exe_path.display(),
config = config_path.display(),
)
}
fn run_cmd(cmd: &str, args: &[&str]) {
match Command::new(cmd).args(args).output() {
Ok(output) => {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
}
}
Err(e) => {
eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
}
}
}
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
let domain_hex = hex::encode(domain);
println!("=== Proxy Links ===");
println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex);
println!();
println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup.");
println!("Check: journalctl -u telemt.service | head -30");
println!("===================");
}

View File

@@ -1,12 +1,89 @@
//! Configuration //! Configuration
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::{IpAddr, SocketAddr};
use std::path::Path; use std::path::Path;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
// ============= Helper Defaults =============
fn default_true() -> bool { true }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_replay_window_secs() -> u64 { 1800 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_weight() -> u16 { 1 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
// ============= Log Level =============
/// Logging verbosity level
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// All messages including trace (trace + debug + info + warn + error)
Debug,
/// Detailed operational logs (debug + info + warn + error)
Verbose,
/// Standard operational logs (info + warn + error)
#[default]
Normal,
/// Minimal output: only warnings and errors (warn + error).
/// Startup messages (config, DC connectivity, proxy links) are always shown
/// via info! before the filter is applied.
Silent,
}
impl LogLevel {
/// Convert to tracing EnvFilter directive string
pub fn to_filter_str(&self) -> &'static str {
match self {
LogLevel::Debug => "trace",
LogLevel::Verbose => "debug",
LogLevel::Normal => "info",
LogLevel::Silent => "warn",
}
}
/// Parse from a loose string (CLI argument)
pub fn from_str_loose(s: &str) -> Self {
match s.to_lowercase().as_str() {
"debug" | "trace" => LogLevel::Debug,
"verbose" => LogLevel::Verbose,
"normal" | "info" => LogLevel::Normal,
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
_ => LogLevel::Normal,
}
}
}
impl std::fmt::Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogLevel::Debug => write!(f, "debug"),
LogLevel::Verbose => write!(f, "verbose"),
LogLevel::Normal => write!(f, "normal"),
LogLevel::Silent => write!(f, "silent"),
}
}
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes { pub struct ProxyModes {
#[serde(default)] #[serde(default)]
@@ -17,8 +94,6 @@ pub struct ProxyModes {
pub tls: bool, pub tls: bool,
} }
fn default_true() -> bool { true }
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { classic: true, secure: true, tls: true } Self { classic: true, secure: true, tls: true }
@@ -26,31 +101,10 @@ impl Default for ProxyModes {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig { pub struct GeneralConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub ad_tag: Option<String>,
#[serde(default)] #[serde(default)]
pub modes: ProxyModes, pub modes: ProxyModes,
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)] #[serde(default)]
pub prefer_ipv6: bool, pub prefer_ipv6: bool,
@@ -59,34 +113,32 @@ pub struct ProxyConfig {
#[serde(default)] #[serde(default)]
pub use_middle_proxy: bool, pub use_middle_proxy: bool,
#[serde(default)]
pub ad_tag: Option<String>,
#[serde(default)] #[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>, pub log_level: LogLevel,
}
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>, impl Default for GeneralConfig {
fn default() -> Self {
#[serde(default)] Self {
pub user_data_quota: HashMap<String, u64>, modes: ProxyModes::default(),
prefer_ipv6: false,
#[serde(default = "default_replay_check_len")] fast_mode: true,
pub replay_check_len: usize, use_middle_proxy: false,
ad_tag: None,
#[serde(default)] log_level: LogLevel::Normal,
pub ignore_time_skew: bool, }
}
#[serde(default = "default_handshake_timeout")] }
pub client_handshake_timeout: u64,
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default = "default_connect_timeout")] pub struct ServerConfig {
pub tg_connect_timeout: u64, #[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack_timeout: u64,
#[serde(default = "default_listen_addr")] #[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String, pub listen_addr_ipv4: String,
@@ -101,65 +153,206 @@ pub struct ProxyConfig {
#[serde(default = "default_metrics_whitelist")] #[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>, pub metrics_whitelist: Vec<IpAddr>,
#[serde(default = "default_fake_cert_len")] #[serde(default)]
pub fake_cert_len: usize, pub listeners: Vec<ListenerConfig>,
} }
fn default_port() -> u16 { 443 } impl Default for ServerConfig {
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 10 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 600 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
impl Default for ProxyConfig {
fn default() -> Self { fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self { Self {
port: default_port(), port: default_port(),
users,
ad_tag: None,
modes: ProxyModes::default(),
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
client_handshake_timeout: default_handshake_timeout(),
tg_connect_timeout: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack_timeout: default_ack_timeout(),
listen_addr_ipv4: default_listen_addr(), listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()), listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None, listen_unix_sock: None,
metrics_port: None, metrics_port: None,
metrics_whitelist: default_metrics_whitelist(), metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
fake_cert_len: default_fake_cert_len(), fake_cert_len: default_fake_cert_len(),
} }
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default = "default_replay_window_secs")]
pub replay_window_secs: u64,
#[serde(default)]
pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
replay_window_secs: default_replay_window_secs(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType {
Direct {
#[serde(default)]
interface: Option<String>,
},
Socks4 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
user_id: Option<String>,
},
Socks5 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
username: Option<String>,
#[serde(default)]
password: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamConfig {
#[serde(flatten)]
pub upstream_type: UpstreamType,
#[serde(default = "default_weight")]
pub weight: u16,
#[serde(default = "default_true")]
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListenerConfig {
pub ip: IpAddr,
#[serde(default)]
pub announce_ip: Option<IpAddr>,
}
// ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig {
#[serde(default)]
pub general: GeneralConfig,
#[serde(default)]
pub server: ServerConfig,
#[serde(default)]
pub timeouts: TimeoutsConfig,
#[serde(default)]
pub censorship: AntiCensorshipConfig,
#[serde(default)]
pub access: AccessConfig,
#[serde(default)]
pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub show_link: Vec<String>,
/// DC address overrides for non-standard DCs (CDN, media, test, etc.)
/// Keys are DC indices as strings, values are "ip:port" addresses.
/// Matches the C implementation's `proxy_for <dc_id> <ip>:<port>` config directive.
/// Example in config.toml:
/// [dc_overrides]
/// "203" = "149.154.175.100:443"
#[serde(default)]
pub dc_overrides: HashMap<String, String>,
/// Default DC index (1-5) for unmapped non-standard DCs.
/// Matches the C implementation's `default <dc_id>` config directive.
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
#[serde(default)]
pub default_dc: Option<u8>,
}
impl ProxyConfig { impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path) let content = std::fs::read_to_string(path)
@@ -169,7 +362,7 @@ impl ProxyConfig {
.map_err(|e| ProxyError::Config(e.to_string()))?; .map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.users { for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret { return Err(ProxyError::InvalidSecret {
user: user.clone(), user: user.clone(),
@@ -178,50 +371,65 @@ impl ProxyConfig {
} }
} }
// Default mask_host // Validate tls_domain
if config.mask_host.is_none() { if config.censorship.tls_domain.is_empty() {
config.mask_host = Some(config.tls_domain.clone()); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
// Random fake_cert_len // Random fake_cert_len
use rand::Rng; use rand::Rng;
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
// Migration: Populate listeners if empty
if config.server.listeners.is_empty() {
if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv4,
announce_ip: None,
});
}
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv6,
announce_ip: None,
});
}
}
}
// Migration: Populate upstreams if empty (Default Direct)
if config.upstreams.is_empty() {
config.upstreams.push(UpstreamConfig {
upstream_type: UpstreamType::Direct { interface: None },
weight: 1,
enabled: true,
});
}
Ok(config) Ok(config)
} }
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.users.is_empty() { if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string())); return Err(ProxyError::Config("No users configured".to_string()));
} }
if !self.modes.classic && !self.modes.secure && !self.modes.tls { if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
));
}
Ok(()) Ok(())
} }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = ProxyConfig::default();
assert_eq!(config.port, 443);
assert!(config.modes.tls);
assert_eq!(config.client_keepalive, 600);
assert_eq!(config.client_ack_timeout, 300);
}
#[test]
fn test_config_validate() {
let mut config = ProxyConfig::default();
assert!(config.validate().is_ok());
config.users.clear();
assert!(config.validate().is_err());
}
} }

View File

@@ -1,21 +1,39 @@
//! AES //! AES encryption implementations
//!
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
//!
//! ## Zeroize policy
//!
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
//! The expanded key schedule lives inside that type and cannot be
//! zeroized from outside. Callers that hold raw key material (e.g.
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
//! zeroizing their own copies.
use aes::Aes256; use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor}; use zeroize::Zeroize;
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>; type Aes256Ctr = Ctr128BE<Aes256>;
type Aes256CbcEnc = CbcEncryptor<Aes256>;
type Aes256CbcDec = CbcDecryptor<Aes256>; // ============= AES-256-CTR =============
/// AES-256-CTR encryptor/decryptor /// AES-256-CTR encryptor/decryptor
///
/// CTR mode is symmetric — encryption and decryption are the same operation.
///
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
/// + counter) is opaque and cannot be zeroized. If you need to protect key
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
/// before dropping them.
pub struct AesCtr { pub struct AesCtr {
cipher: Aes256Ctr, cipher: Aes256Ctr,
} }
impl AesCtr { impl AesCtr {
/// Create new AES-CTR cipher with key and IV
pub fn new(key: &[u8; 32], iv: u128) -> Self { pub fn new(key: &[u8; 32], iv: u128) -> Self {
let iv_bytes = iv.to_be_bytes(); let iv_bytes = iv.to_be_bytes();
Self { Self {
@@ -23,6 +41,7 @@ impl AesCtr {
} }
} }
/// Create from key and IV slices
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
@@ -54,17 +73,37 @@ impl AesCtr {
} }
} }
/// AES-256-CBC Ciphermagic // ============= AES-256-CBC =============
/// AES-256-CBC cipher with proper chaining
///
/// Unlike CTR mode, CBC is NOT symmetric — encryption and decryption
/// are different operations. This implementation handles CBC chaining
/// correctly across multiple blocks.
///
/// Key and IV are zeroized on drop.
pub struct AesCbc { pub struct AesCbc {
key: [u8; 32], key: [u8; 32],
iv: [u8; 16], iv: [u8; 16],
} }
impl Drop for AesCbc {
fn drop(&mut self) {
self.key.zeroize();
self.iv.zeroize();
}
}
impl AesCbc { impl AesCbc {
/// AES block size
const BLOCK_SIZE: usize = 16;
/// Create new AES-CBC cipher with key and IV
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
Self { key, iv } Self { key, iv }
} }
/// Create from slices
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
@@ -79,32 +118,36 @@ impl AesCbc {
}) })
} }
/// Encrypt data using CBC mode /// Encrypt a single block using raw AES (no chaining)
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> { fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
if data.len() % 16 != 0 { use aes::cipher::BlockEncrypt;
return Err(ProxyError::Crypto( let mut output = *block;
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) key_schedule.encrypt_block((&mut output).into());
)); output
}
if data.is_empty() {
return Ok(Vec::new());
}
let mut buffer = data.to_vec();
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
for chunk in buffer.chunks_mut(16) {
encryptor.encrypt_block_mut(chunk.into());
}
Ok(buffer)
} }
/// Decrypt data using CBC mode /// Decrypt a single block using raw AES (no chaining)
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> { fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
if data.len() % 16 != 0 { use aes::cipher::BlockDecrypt;
let mut output = *block;
key_schedule.decrypt_block((&mut output).into());
output
}
/// XOR two 16-byte blocks
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
let mut result = [0u8; 16];
for i in 0..16 {
result[i] = a[i] ^ b[i];
}
result
}
/// Encrypt data using CBC mode with proper chaining
///
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) format!("CBC data must be aligned to 16 bytes, got {}", data.len())
)); ));
@@ -114,20 +157,57 @@ impl AesCbc {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let mut buffer = data.to_vec(); use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv;
for chunk in buffer.chunks_mut(16) { for chunk in data.chunks(Self::BLOCK_SIZE) {
decryptor.decrypt_block_mut(chunk.into()); let plaintext: [u8; 16] = chunk.try_into().unwrap();
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
let ciphertext = self.encrypt_block(&xored, &key_schedule);
prev_ciphertext = ciphertext;
result.extend_from_slice(&ciphertext);
} }
Ok(buffer) Ok(result)
}
/// Decrypt data using CBC mode with proper chaining
///
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
}
if data.is_empty() {
return Ok(Vec::new());
}
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv;
for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
prev_ciphertext = ciphertext;
result.extend_from_slice(&plaintext);
}
Ok(result)
} }
/// Encrypt data in-place /// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 { if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) format!("CBC data must be aligned to 16 bytes, got {}", data.len())
)); ));
@@ -137,10 +217,22 @@ impl AesCbc {
return Ok(()); return Ok(());
} }
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into()); use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
for chunk in data.chunks_mut(16) { let mut prev_ciphertext = self.iv;
encryptor.encrypt_block_mut(chunk.into());
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule);
prev_ciphertext = *block_array;
} }
Ok(()) Ok(())
@@ -148,7 +240,7 @@ impl AesCbc {
/// Decrypt data in-place /// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if data.len() % 16 != 0 { if data.len() % Self::BLOCK_SIZE != 0 {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) format!("CBC data must be aligned to 16 bytes, got {}", data.len())
)); ));
@@ -158,16 +250,32 @@ impl AesCbc {
return Ok(()); return Ok(());
} }
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into()); use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
for chunk in data.chunks_mut(16) { let mut prev_ciphertext = self.iv;
decryptor.decrypt_block_mut(chunk.into());
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule);
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
prev_ciphertext = current_ciphertext;
} }
Ok(()) Ok(())
} }
} }
// ============= Encryption Traits =============
/// Trait for unified encryption interface /// Trait for unified encryption interface
pub trait Encryptor: Send + Sync { pub trait Encryptor: Send + Sync {
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>; fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
@@ -209,6 +317,8 @@ impl Decryptor for PassthroughEncryptor {
mod tests { mod tests {
use super::*; use super::*;
// ============= AES-CTR Tests =============
#[test] #[test]
fn test_aes_ctr_roundtrip() { fn test_aes_ctr_roundtrip() {
let key = [0u8; 32]; let key = [0u8; 32];
@@ -225,12 +335,32 @@ mod tests {
assert_eq!(original.as_slice(), decrypted.as_slice()); assert_eq!(original.as_slice(), decrypted.as_slice());
} }
#[test]
fn test_aes_ctr_in_place() {
let key = [0x42u8; 32];
let iv = 999u128;
let original = b"Test data for in-place encryption";
let mut data = original.to_vec();
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
assert_ne!(&data[..], original);
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
assert_eq!(&data[..], original);
}
// ============= AES-CBC Tests =============
#[test] #[test]
fn test_aes_cbc_roundtrip() { fn test_aes_cbc_roundtrip() {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
// Must be aligned to 16 bytes
let original = [0u8; 32]; let original = [0u8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -244,45 +374,48 @@ mod tests {
fn test_aes_cbc_chaining_works() { fn test_aes_cbc_chaining_works() {
let key = [0x42u8; 32]; let key = [0x42u8; 32];
let iv = [0x00u8; 16]; let iv = [0x00u8; 16];
let plaintext = [0xAA_u8; 32]; let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// CBC Corrections
let block1 = &ciphertext[0..16]; let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32]; let block2 = &ciphertext[16..32];
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext"); assert_ne!(
block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext"
);
} }
#[test] #[test]
fn test_aes_cbc_known_vector() { fn test_aes_cbc_known_vector() {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let plaintext = [0u8; 16];
// 3 Datablocks
let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 2
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
// Block 3 - different
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt + Verify
let decrypted = cipher.decrypt(&ciphertext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice()); assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Verify Ciphertexts Block 1 != Block 2 assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]); }
#[test]
fn test_aes_cbc_multi_block() {
let key = [0x12u8; 32];
let iv = [0x34u8; 16];
let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext, decrypted);
} }
#[test] #[test]
@@ -290,8 +423,8 @@ mod tests {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks let original = [0x56u8; 48];
let mut buffer = original.clone(); let mut buffer = original;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -317,35 +450,93 @@ mod tests {
fn test_aes_cbc_unaligned_error() { fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]); let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes
let result = cipher.encrypt(&[0u8; 15]); let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err()); assert!(result.is_err());
// 17 bytes
let result = cipher.encrypt(&[0u8; 17]); let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn test_aes_cbc_avalanche_effect() { fn test_aes_cbc_avalanche_effect() {
// Cipherplane
let key = [0xAB; 32]; let key = [0xAB; 32];
let iv = [0xCD; 16]; let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32]; let plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32]; let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Один бит отличается plaintext2[0] = 0x01;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First Blocks Diff
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second Blocks Diff
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
} }
#[test]
fn test_aes_cbc_iv_matters() {
let key = [0x55; 32];
let plaintext = [0x77u8; 16];
let cipher1 = AesCbc::new(key, [0u8; 16]);
let cipher2 = AesCbc::new(key, [1u8; 16]);
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
assert_ne!(ciphertext1, ciphertext2);
}
#[test]
fn test_aes_cbc_deterministic() {
let key = [0x99; 32];
let iv = [0x88; 16];
let plaintext = [0x77u8; 32];
let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
assert_eq!(ciphertext1, ciphertext2);
}
// ============= Zeroize Tests =============
#[test]
fn test_aes_cbc_zeroize_on_drop() {
let key = [0xAA; 32];
let iv = [0xBB; 16];
let cipher = AesCbc::new(key, iv);
// Verify key/iv are set
assert_eq!(cipher.key, [0xAA; 32]);
assert_eq!(cipher.iv, [0xBB; 16]);
drop(cipher);
// After drop, key/iv are zeroized (can't observe directly,
// but the Drop impl runs without panic)
}
// ============= Error Handling Tests =============
#[test]
fn test_invalid_key_length() {
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err());
}
#[test]
fn test_invalid_iv_length() {
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err());
}
} }

View File

@@ -1,3 +1,16 @@
//! Cryptographic hash functions
//!
//! ## Protocol-required algorithms
//!
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
//! hash functions are **required by the Telegram Middle Proxy protocol**
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
//! compatibility. They are NOT used for any security-sensitive purpose
//! outside of that specific key derivation scheme mandated by Telegram.
//!
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
//! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256; use sha2::Sha256;
use md5::Md5; use md5::Md5;
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
mac.finalize().into_bytes().into() mac.finalize().into_bytes().into()
} }
/// SHA-1 /// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn sha1(data: &[u8]) -> [u8; 20] { pub fn sha1(data: &[u8]) -> [u8; 20] {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(data); hasher.update(data);
hasher.finalize().into() hasher.finalize().into()
} }
/// MD5 /// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn md5(data: &[u8]) -> [u8; 16] { pub fn md5(data: &[u8]) -> [u8; 16] {
let mut hasher = Md5::new(); let mut hasher = Md5::new();
hasher.update(data); hasher.update(data);
@@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// Middle Proxy Keygen /// Middle Proxy key derivation
///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
/// These algorithms are NOT replaceable here — changing them would break
/// interoperability with Telegram's middle proxy infrastructure.
pub fn derive_middleproxy_keys( pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16], nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16], nonce_clt: &[u8; 16],

View File

@@ -6,4 +6,4 @@ pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM}; pub use random::SecureRandom;

View File

@@ -3,11 +3,8 @@
use rand::{Rng, RngCore, SeedableRng}; use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use parking_lot::Mutex; use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR /// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom { pub struct SecureRandom {
@@ -20,18 +17,30 @@ struct SecureRandomInner {
buffer: Vec<u8>, buffer: Vec<u8>,
} }
impl Drop for SecureRandomInner {
fn drop(&mut self) {
self.buffer.zeroize();
}
}
impl SecureRandom { impl SecureRandom {
pub fn new() -> Self { pub fn new() -> Self {
let mut rng = StdRng::from_entropy(); let mut seed_source = rand::rng();
let mut rng = StdRng::from_rng(&mut seed_source);
let mut key = [0u8; 32]; let mut key = [0u8; 32];
rng.fill_bytes(&mut key); rng.fill_bytes(&mut key);
let iv: u128 = rng.gen(); let iv: u128 = rng.random();
let cipher = AesCtr::new(&key, iv);
// Zeroize local key copy — cipher already consumed it
key.zeroize();
Self { Self {
inner: Mutex::new(SecureRandomInner { inner: Mutex::new(SecureRandomInner {
rng, rng,
cipher: AesCtr::new(&key, iv), cipher,
buffer: Vec::with_capacity(1024), buffer: Vec::with_capacity(1024),
}), }),
} }
@@ -78,7 +87,6 @@ impl SecureRandom {
result |= (b as u64) << (i * 8); result |= (b as u64) << (i * 8);
} }
// Mask extra bits
if k < 64 { if k < 64 {
result &= (1u64 << k) - 1; result &= (1u64 << k) - 1;
} }
@@ -107,13 +115,13 @@ impl SecureRandom {
/// Generate random u32 /// Generate random u32
pub fn u32(&self) -> u32 { pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
/// Generate random u64 /// Generate random u64
pub fn u64(&self) -> u64 { pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
} }
@@ -162,12 +170,10 @@ mod tests {
fn test_bits() { fn test_bits() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
// Single bit should be 0 or 1
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(1) <= 1); assert!(rng.bits(1) <= 1);
} }
// 8 bits should be 0-255
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(8) <= 255); assert!(rng.bits(8) <= 255);
} }
@@ -185,10 +191,8 @@ mod tests {
} }
} }
// Should have seen all items
assert_eq!(seen.len(), 5); assert_eq!(seen.len(), 5);
// Empty slice should return None
let empty: Vec<i32> = vec![]; let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none()); assert!(rng.choose(&empty).is_none());
} }
@@ -201,12 +205,10 @@ mod tests {
let mut shuffled = original.clone(); let mut shuffled = original.clone();
rng.shuffle(&mut shuffled); rng.shuffle(&mut shuffled);
// Should contain same elements
let mut sorted = shuffled.clone(); let mut sorted = shuffled.clone();
sorted.sort(); sorted.sort();
assert_eq!(sorted, original); assert_eq!(sorted, original);
// Should be different order (with very high probability)
assert_ne!(shuffled, original); assert_ne!(shuffled, original);
} }
} }

View File

@@ -1,8 +1,170 @@
//! Error Types //! Error Types
use std::fmt;
use std::net::SocketAddr; use std::net::SocketAddr;
use thiserror::Error; use thiserror::Error;
// ============= Stream Errors =============
/// Errors specific to stream I/O operations
#[derive(Debug)]
pub enum StreamError {
/// Partial read: got fewer bytes than expected
PartialRead {
expected: usize,
got: usize,
},
/// Partial write: wrote fewer bytes than expected
PartialWrite {
expected: usize,
written: usize,
},
/// Stream is in poisoned state and cannot be used
Poisoned {
reason: String,
},
/// Buffer overflow: attempted to buffer more than allowed
BufferOverflow {
limit: usize,
attempted: usize,
},
/// Invalid frame format
InvalidFrame {
details: String,
},
/// Unexpected end of stream
UnexpectedEof,
/// Underlying I/O error
Io(std::io::Error),
}
impl fmt::Display for StreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::PartialRead { expected, got } => {
write!(f, "partial read: expected {} bytes, got {}", expected, got)
}
Self::PartialWrite { expected, written } => {
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
}
Self::Poisoned { reason } => {
write!(f, "stream poisoned: {}", reason)
}
Self::BufferOverflow { limit, attempted } => {
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
}
Self::InvalidFrame { details } => {
write!(f, "invalid frame: {}", details)
}
Self::UnexpectedEof => {
write!(f, "unexpected end of stream")
}
Self::Io(e) => {
write!(f, "I/O error: {}", e)
}
}
}
}
impl std::error::Error for StreamError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
_ => None,
}
}
}
impl From<std::io::Error> for StreamError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
impl From<StreamError> for std::io::Error {
fn from(err: StreamError) -> Self {
match err {
StreamError::Io(e) => e,
StreamError::UnexpectedEof => {
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
}
StreamError::Poisoned { .. } => {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
StreamError::BufferOverflow { .. } => {
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
}
StreamError::InvalidFrame { .. } => {
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
}
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
}
}
}
// ============= Recoverable Trait =============
/// Trait for errors that may be recoverable
pub trait Recoverable {
/// Check if error is recoverable (can retry operation)
fn is_recoverable(&self) -> bool;
/// Check if connection can continue after this error
fn can_continue(&self) -> bool;
}
impl Recoverable for StreamError {
fn is_recoverable(&self) -> bool {
match self {
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
Self::Io(e) => matches!(
e.kind(),
std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
),
Self::Poisoned { .. }
| Self::BufferOverflow { .. }
| Self::InvalidFrame { .. }
| Self::UnexpectedEof => false,
}
}
fn can_continue(&self) -> bool {
match self {
Self::Poisoned { .. } => false,
Self::UnexpectedEof => false,
Self::BufferOverflow { .. } => false,
_ => true,
}
}
}
impl Recoverable for std::io::Error {
fn is_recoverable(&self) -> bool {
matches!(
self.kind(),
std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
)
}
fn can_continue(&self) -> bool {
!matches!(
self.kind(),
std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::NotConnected
)
}
}
// ============= Main Proxy Errors =============
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProxyError { pub enum ProxyError {
// ============= Crypto Errors ============= // ============= Crypto Errors =============
@@ -13,6 +175,11 @@ pub enum ProxyError {
#[error("Invalid key length: expected {expected}, got {got}")] #[error("Invalid key length: expected {expected}, got {got}")]
InvalidKeyLength { expected: usize, got: usize }, InvalidKeyLength { expected: usize, got: usize },
// ============= Stream Errors =============
#[error("Stream error: {0}")]
Stream(#[from] StreamError),
// ============= Protocol Errors ============= // ============= Protocol Errors =============
#[error("Invalid handshake: {0}")] #[error("Invalid handshake: {0}")]
@@ -39,6 +206,12 @@ pub enum ProxyError {
#[error("Sequence number mismatch: expected={expected}, got={got}")] #[error("Sequence number mismatch: expected={expected}, got={got}")]
SeqNoMismatch { expected: i32, got: i32 }, SeqNoMismatch { expected: i32, got: i32 },
#[error("TLS handshake failed: {reason}")]
TlsHandshakeFailed { reason: String },
#[error("Telegram handshake timeout")]
TgHandshakeTimeout,
// ============= Network Errors ============= // ============= Network Errors =============
#[error("Connection timeout to {addr}")] #[error("Connection timeout to {addr}")]
@@ -55,6 +228,9 @@ pub enum ProxyError {
#[error("Invalid proxy protocol header")] #[error("Invalid proxy protocol header")]
InvalidProxyProtocol, InvalidProxyProtocol,
#[error("Proxy error: {0}")]
Proxy(String),
// ============= Config Errors ============= // ============= Config Errors =============
#[error("Config error: {0}")] #[error("Config error: {0}")]
@@ -77,27 +253,53 @@ pub enum ProxyError {
#[error("Unknown user")] #[error("Unknown user")]
UnknownUser, UnknownUser,
#[error("Rate limited")]
RateLimited,
// ============= General Errors ============= // ============= General Errors =============
#[error("Internal error: {0}")] #[error("Internal error: {0}")]
Internal(String), Internal(String),
} }
impl Recoverable for ProxyError {
fn is_recoverable(&self) -> bool {
match self {
Self::Stream(e) => e.is_recoverable(),
Self::Io(e) => e.is_recoverable(),
Self::ConnectionTimeout { .. } => true,
Self::RateLimited => true,
_ => false,
}
}
fn can_continue(&self) -> bool {
match self {
Self::Stream(e) => e.can_continue(),
Self::Io(e) => e.can_continue(),
_ => false,
}
}
}
/// Convenient Result type alias /// Convenient Result type alias
pub type Result<T> = std::result::Result<T, ProxyError>; pub type Result<T> = std::result::Result<T, ProxyError>;
/// Result type for stream operations
pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Result with optional bad client handling /// Result with optional bad client handling
#[derive(Debug)] #[derive(Debug)]
pub enum HandshakeResult<T> { pub enum HandshakeResult<T, R, W> {
/// Handshake succeeded /// Handshake succeeded
Success(T), Success(T),
/// Client failed validation, needs masking /// Client failed validation, needs masking. Returns ownership of streams.
BadClient, BadClient { reader: R, writer: W },
/// Error occurred /// Error occurred
Error(ProxyError), Error(ProxyError),
} }
impl<T> HandshakeResult<T> { impl<T, R, W> HandshakeResult<T, R, W> {
/// Check if successful /// Check if successful
pub fn is_success(&self) -> bool { pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_)) matches!(self, HandshakeResult::Success(_))
@@ -105,58 +307,87 @@ impl<T> HandshakeResult<T> {
/// Check if bad client /// Check if bad client
pub fn is_bad_client(&self) -> bool { pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient) matches!(self, HandshakeResult::BadClient { .. })
}
/// Convert to Result, treating BadClient as error
pub fn into_result(self) -> Result<T> {
match self {
HandshakeResult::Success(v) => Ok(v),
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
HandshakeResult::Error(e) => Err(e),
}
} }
/// Map the success value /// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient => HandshakeResult::BadClient, HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
} }
impl<T> From<ProxyError> for HandshakeResult<T> { impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
fn from(err: ProxyError) -> Self { fn from(err: ProxyError) -> Self {
HandshakeResult::Error(err) HandshakeResult::Error(err)
} }
} }
impl<T> From<std::io::Error> for HandshakeResult<T> { impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
fn from(err: std::io::Error) -> Self { fn from(err: std::io::Error) -> Self {
HandshakeResult::Error(ProxyError::Io(err)) HandshakeResult::Error(ProxyError::Io(err))
} }
} }
impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
fn from(err: StreamError) -> Self {
HandshakeResult::Error(ProxyError::Stream(err))
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn test_stream_error_display() {
let err = StreamError::PartialRead { expected: 100, got: 50 };
assert!(err.to_string().contains("100"));
assert!(err.to_string().contains("50"));
let err = StreamError::Poisoned { reason: "test".into() };
assert!(err.to_string().contains("test"));
}
#[test]
fn test_stream_error_recoverable() {
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
assert!(!StreamError::UnexpectedEof.is_recoverable());
}
#[test]
fn test_stream_error_can_continue() {
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
assert!(!StreamError::UnexpectedEof.can_continue());
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
}
#[test]
fn test_stream_error_to_io_error() {
let stream_err = StreamError::UnexpectedEof;
let io_err: std::io::Error = stream_err.into();
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
}
#[test] #[test]
fn test_handshake_result() { fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
assert!(success.is_success()); assert!(success.is_success());
assert!(!success.is_bad_client()); assert!(!success.is_bad_client());
let bad: HandshakeResult<i32> = HandshakeResult::BadClient; let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
assert!(!bad.is_success()); assert!(!bad.is_success());
assert!(bad.is_bad_client()); assert!(bad.is_bad_client());
} }
#[test] #[test]
fn test_handshake_result_map() { fn test_handshake_result_map() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2); let mapped = success.map(|x| x * 2);
match mapped { match mapped {
@@ -165,6 +396,15 @@ mod tests {
} }
} }
#[test]
fn test_proxy_error_recoverable() {
let err = ProxyError::RateLimited;
assert!(err.is_recoverable());
let err = ProxyError::InvalidHandshake("bad".into());
assert!(!err.is_recoverable());
}
#[test] #[test]
fn test_error_display() { fn test_error_display() {
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() }; let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };

View File

@@ -1,158 +1,306 @@
//! Telemt - MTProxy on Rust //! Telemt - MTProxy on Rust
use std::sync::Arc;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::signal; use tokio::signal;
use tracing::{info, error, Level}; use tokio::sync::Semaphore;
use tracing_subscriber::{FmtSubscriber, EnvFilter}; use tracing::{info, error, warn, debug};
use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*};
mod error; mod cli;
mod config;
mod crypto; mod crypto;
mod error;
mod protocol; mod protocol;
mod proxy;
mod stats;
mod stream; mod stream;
mod transport; mod transport;
mod proxy;
mod config;
mod stats;
mod util; mod util;
use config::ProxyConfig; use crate::config::{ProxyConfig, LogLevel};
use stats::{Stats, ReplayChecker}; use crate::proxy::ClientHandler;
use transport::ConnectionPool; use crate::stats::{Stats, ReplayChecker};
use proxy::ClientHandler; use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
#[tokio::main] fn parse_cli() -> (String, bool, Option<String>) {
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { let mut config_path = "config.toml".to_string();
// Initialize logging with env filter let mut silent = false;
// Use RUST_LOG=debug or RUST_LOG=trace for more details let mut log_level: Option<String> = None;
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let subscriber = FmtSubscriber::builder() let args: Vec<String> = std::env::args().skip(1).collect();
.with_env_filter(filter)
.with_target(true)
.with_thread_ids(false)
.with_file(false)
.with_line_number(false)
.finish();
tracing::subscriber::set_global_default(subscriber)?; // Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
// Load configuration if let Err(e) = cli::run_init(init_opts) {
let config_path = std::env::args() eprintln!("[telemt] Init failed: {}", e);
.nth(1) std::process::exit(1);
.unwrap_or_else(|| "config.toml".to_string()); }
std::process::exit(0);
info!("Loading configuration from {}", config_path);
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
error!("Failed to load config: {}", e);
info!("Using default configuration");
ProxyConfig::default()
});
if let Err(e) = config.validate() {
error!("Invalid configuration: {}", e);
std::process::exit(1);
} }
let config = Arc::new(config); let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => { silent = true; }
"--log-level" => {
i += 1;
if i < args.len() { log_level = Some(args[i].clone()); }
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start");
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
}
i += 1;
}
info!("Starting MTProto Proxy on port {}", config.port); (config_path, silent, log_level)
info!("Fast mode: {}", config.fast_mode); }
info!("Modes: classic={}, secure={}, tls={}",
config.modes.classic, config.modes.secure, config.modes.tls); #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize components let (config_path, cli_silent, cli_log_level) = parse_cli();
let stats = Arc::new(Stats::new());
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); let config = match ProxyConfig::load(&config_path) {
let pool = Arc::new(ConnectionPool::new()); Ok(c) => c,
Err(e) => {
// Create handler if std::path::Path::new(&config_path).exists() {
let handler = Arc::new(ClientHandler::new( eprintln!("[telemt] Error: {}", e);
Arc::clone(&config), std::process::exit(1);
Arc::clone(&stats), } else {
Arc::clone(&replay_checker), let default = ProxyConfig::default();
Arc::clone(&pool), std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
)); eprintln!("[telemt] Created default config at {}", config_path);
default
// Start listener
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
.parse()?;
let listener = TcpListener::bind(addr).await?;
info!("Listening on {}", addr);
// Print proxy links
print_proxy_links(&config);
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
// Main accept loop
let accept_loop = async {
loop {
match listener.accept().await {
Ok((stream, peer)) => {
let handler = Arc::clone(&handler);
tokio::spawn(async move {
handler.handle(stream, peer).await;
});
}
Err(e) => {
error!("Accept error: {}", e);
}
} }
} }
}; };
// Graceful shutdown if let Err(e) = config.validate() {
tokio::select! { eprintln!("[telemt] Invalid config: {}", e);
_ = accept_loop => {} std::process::exit(1);
_ = signal::ctrl_c() => {
info!("Shutting down...");
}
} }
// Cleanup
pool.close_all().await;
info!("Goodbye!");
Ok(())
}
fn print_proxy_links(config: &ProxyConfig) { let has_rust_log = std::env::var("RUST_LOG").is_ok();
println!("\n=== Proxy Links ===\n"); let effective_log_level = if cli_silent {
LogLevel::Silent
for (user, secret) in &config.users { } else if let Some(ref s) = cli_log_level {
if config.modes.tls { LogLevel::from_str_loose(s)
let tls_secret = format!( } else {
"ee{}{}", config.general.log_level.clone()
secret, };
hex::encode(config.tls_domain.as_bytes())
); // Start with INFO so startup messages are always visible,
println!( // then switch to user-configured level after startup
"{} (TLS): tg://proxy?server=IP&port={}&secret={}", let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
user, config.port, tls_secret tracing_subscriber::registry()
); .with(filter_layer)
} .with(fmt::Layer::default())
.init();
if config.modes.secure {
println!( info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}", info!("Log level: {}", effective_log_level);
user, config.port, secret info!("Modes: classic={} secure={} tls={}",
); config.general.modes.classic,
} config.general.modes.secure,
config.general.modes.tls);
if config.modes.classic { info!("TLS domain: {}", config.censorship.tls_domain);
println!( info!("Mask: {} -> {}:{}",
"{} (Classic): tg://proxy?server=IP&port={}&secret={}", config.censorship.mask,
user, config.port, secret config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
); config.censorship.mask_port);
}
if config.censorship.tls_domain == "www.google.com" {
println!(); warn!("Using default tls_domain. Consider setting a custom domain.");
} }
println!("===================\n"); let prefer_ipv6 = config.general.prefer_ipv6;
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
let replay_checker = Arc::new(ReplayChecker::new(
config.access.replay_check_len,
Duration::from_secs(config.access.replay_window_secs),
));
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Connection concurrency limit — prevents OOM under SYN flood / connection storm.
// 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead).
// 10000 connections ≈ 640MB peak memory.
let max_connections = Arc::new(Semaphore::new(10_000));
// Startup DC ping
info!("=== Telegram DC Connectivity ===");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
info!(" via {}", upstream_result.upstream_name);
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) {
(Some(rtt), _) => {
info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
}
(None, Some(err)) => {
info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
}
_ => {
info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
}
}
}
}
info!("================================");
// Background tasks
let um_clone = upstream_manager.clone();
tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default()
};
match create_listener(addr, &options) {
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip
} else if listener_conf.ip.is_unspecified() {
if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else {
detected_ip.ipv6.unwrap_or(listener_conf.ip)
}
} else {
listener_conf.ip
};
if !config.show_link.is_empty() {
info!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret);
}
if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex);
}
} else {
warn!("User '{}' in show_link not found", user_name);
}
}
info!("------------------------");
}
listeners.push(listener);
},
Err(e) => {
error!("Failed to bind to {}: {}", addr, e);
}
}
}
if listeners.is_empty() {
error!("No listeners. Exiting.");
std::process::exit(1);
}
// Switch to user-configured log level after startup
let runtime_filter = if has_rust_log {
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
filter_handle.reload(runtime_filter).expect("Failed to switch log filter");
for listener in listeners {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move {
if let Err(e) = ClientHandler::new(
stream, peer_addr, config, stats,
upstream_manager, replay_checker, buffer_pool, rng
).run().await {
debug!(peer = %peer_addr, error = %e, "Connection error");
}
});
}
Err(e) => {
error!("Accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
});
}
match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e),
}
Ok(())
} }

View File

@@ -1,13 +1,13 @@
//! Protocol constants and datacenter addresses //! Protocol constants and datacenter addresses
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy; use std::sync::LazyLock;
// ============= Telegram Datacenters ============= // ============= Telegram Datacenters =============
pub const TG_DATACENTER_PORT: u16 = 443; pub const TG_DATACENTER_PORT: u16 = 443;
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)), IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
] ]
}); });
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()), IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()), IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
// ============= Middle Proxies (for advertising) ============= // ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
m m
}); });
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
@@ -167,7 +167,8 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes ============= // ============= Buffer Sizes =============
/// Default buffer size /// Default buffer size
pub const DEFAULT_BUFFER_SIZE: usize = 65536; pub const DEFAULT_BUFFER_SIZE: usize = 16384;
/// Small buffer size for bad client handling /// Small buffer size for bad client handling
pub const SMALL_BUFFER_SIZE: usize = 8192; pub const SMALL_BUFFER_SIZE: usize = 8192;

View File

@@ -1,10 +1,13 @@
//! MTProto Obfuscation //! MTProto Obfuscation
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr};
use crate::error::Result; use crate::error::Result;
use super::constants::*; use super::constants::*;
/// Obfuscation parameters from handshake /// Obfuscation parameters from handshake
///
/// Key material is zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ObfuscationParams { pub struct ObfuscationParams {
/// Key for decrypting client -> proxy traffic /// Key for decrypting client -> proxy traffic
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
pub dc_idx: i16, pub dc_idx: i16,
} }
impl Drop for ObfuscationParams {
fn drop(&mut self) {
self.decrypt_key.zeroize();
self.decrypt_iv.zeroize();
self.encrypt_key.zeroize();
self.encrypt_iv.zeroize();
}
}
impl ObfuscationParams { impl ObfuscationParams {
/// Parse obfuscation parameters from handshake bytes /// Parse obfuscation parameters from handshake bytes
/// Returns None if handshake doesn't match any user secret /// Returns None if handshake doesn't match any user secret
pub fn from_handshake( pub fn from_handshake(
handshake: &[u8; HANDSHAKE_LEN], handshake: &[u8; HANDSHAKE_LEN],
secrets: &[(String, Vec<u8>)], // (username, secret_bytes) secrets: &[(String, Vec<u8>)],
) -> Option<(Self, String)> { ) -> Option<(Self, String)> {
// Extract prekey and IV for decryption
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets { for (username, secret) in secrets {
// Derive decryption key
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret); dec_key_input.extend_from_slice(secret);
@@ -47,26 +56,22 @@ impl ObfuscationParams {
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Create decryptor and decrypt handshake
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => continue, // Try next secret None => continue,
}; };
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret); enc_key_input.extend_from_slice(secret);
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
/// Check if nonce is valid (not matching reserved patterns) /// Check if nonce is valid (not matching reserved patterns)
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
// Check first byte
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false; return false;
} }
// Check first 4 bytes
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false; return false;
} }
// Check bytes 4-7
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false; return false;
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
pub fn prepare_tg_nonce( pub fn prepare_tg_nonce(
nonce: &mut [u8; HANDSHAKE_LEN], nonce: &mut [u8; HANDSHAKE_LEN],
proto_tag: ProtoTag, proto_tag: ProtoTag,
enc_key_iv: Option<&[u8]>, // For fast mode enc_key_iv: Option<&[u8]>,
) { ) {
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// For fast mode, copy the reversed enc_key_iv
if let Some(key_iv) = enc_key_iv { if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect(); let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
@@ -161,14 +161,12 @@ pub fn prepare_tg_nonce(
/// Encrypt the outgoing nonce for Telegram /// Encrypt the outgoing nonce for Telegram
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// Derive encryption key from the nonce itself
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv); let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv); let mut encryptor = AesCtr::new(&enc_key, enc_iv);
// Only encrypt from PROTO_TAG_POS onwards
let mut result = nonce.to_vec(); let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
@@ -182,22 +180,18 @@ mod tests {
#[test] #[test]
fn test_is_valid_nonce() { fn test_is_valid_nonce() {
// Valid nonce
let mut valid = [0x42u8; HANDSHAKE_LEN]; let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]); valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid)); assert!(is_valid_nonce(&valid));
// Invalid: starts with 0xef
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef; invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: starts with HEAD
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD"); invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: bytes 4-7 are zeros
let mut invalid = [0x42u8; HANDSHAKE_LEN]; let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));

View File

@@ -1,14 +1,22 @@
//! Fake TLS 1.3 Handshake //! Fake TLS 1.3 Handshake
//!
//! This module handles the fake TLS 1.3 handshake used by MTProto proxy
//! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use super::constants::*; use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
// ============= Public Constants =============
/// TLS handshake digest length /// TLS handshake digest length
pub const TLS_DIGEST_LEN: usize = 32; pub const TLS_DIGEST_LEN: usize = 32;
/// Position of digest in TLS ClientHello /// Position of digest in TLS ClientHello
pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_POS: usize = 11;
/// Length to store for replay protection (first 16 bytes of digest) /// Length to store for replay protection (first 16 bytes of digest)
pub const TLS_DIGEST_HALF_LEN: usize = 16; pub const TLS_DIGEST_HALF_LEN: usize = 16;
@@ -16,6 +24,26 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
// ============= Private Constants =============
/// TLS Extension types
mod extension_type {
pub const KEY_SHARE: u16 = 0x0033;
pub const SUPPORTED_VERSIONS: u16 = 0x002b;
}
/// TLS Cipher Suites
mod cipher_suite {
pub const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
}
/// TLS Named Curves
mod named_curve {
pub const X25519: u16 = 0x001d;
}
// ============= TLS Validation Result =============
/// Result of validating TLS handshake /// Result of validating TLS handshake
#[derive(Debug)] #[derive(Debug)]
pub struct TlsValidation { pub struct TlsValidation {
@@ -29,7 +57,185 @@ pub struct TlsValidation {
pub timestamp: u32, pub timestamp: u32,
} }
// ============= TLS Extension Builder =============
/// Builder for TLS extensions with correct length calculation
struct TlsExtensionBuilder {
extensions: Vec<u8>,
}
impl TlsExtensionBuilder {
fn new() -> Self {
Self {
extensions: Vec::with_capacity(128),
}
}
/// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
// Extension type: key_share (0x0033)
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
// Key data
self.extensions.extend_from_slice(public_key);
self
}
/// Add Supported Versions extension
fn add_supported_versions(&mut self, version: u16) -> &mut Self {
// Extension type: supported_versions (0x002b)
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
// Extension data: length (2) + version (2)
self.extensions.extend_from_slice(&(2u16).to_be_bytes());
// Selected version
self.extensions.extend_from_slice(&version.to_be_bytes());
self
}
/// Build final extensions with length prefix
fn build(self) -> Vec<u8> {
let mut result = Vec::with_capacity(2 + self.extensions.len());
// Extensions length (2 bytes)
let len = self.extensions.len() as u16;
result.extend_from_slice(&len.to_be_bytes());
// Extensions data
result.extend_from_slice(&self.extensions);
result
}
/// Get current extensions without length prefix (for calculation)
#[allow(dead_code)]
fn as_bytes(&self) -> &[u8] {
&self.extensions
}
}
// ============= ServerHello Builder =============
/// Builder for TLS ServerHello with correct structure
struct ServerHelloBuilder {
/// Random bytes (32 bytes, will contain digest)
random: [u8; 32],
/// Session ID (echoed from ClientHello)
session_id: Vec<u8>,
/// Cipher suite
cipher_suite: [u8; 2],
/// Compression method
compression: u8,
/// Extensions
extensions: TlsExtensionBuilder,
}
impl ServerHelloBuilder {
fn new(session_id: Vec<u8>) -> Self {
Self {
random: [0u8; 32],
session_id,
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0x00,
extensions: TlsExtensionBuilder::new(),
}
}
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
self.extensions.add_key_share(key);
self
}
fn with_tls13_version(mut self) -> Self {
// TLS 1.3 = 0x0304
self.extensions.add_supported_versions(0x0304);
self
}
/// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> {
let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16;
// Calculate total length
let body_len = 2 + // version
32 + // random
1 + self.session_id.len() + // session_id length + data
2 + // cipher suite
1 + // compression
2 + extensions.len(); // extensions length + data
let mut message = Vec::with_capacity(4 + body_len);
// Handshake header
message.push(0x02); // ServerHello message type
// 3-byte length
let len_bytes = (body_len as u32).to_be_bytes();
message.extend_from_slice(&len_bytes[1..4]);
// Server version (TLS 1.2 in header, actual version in extension)
message.extend_from_slice(&TLS_VERSION);
// Random (32 bytes) - placeholder, will be replaced with digest
message.extend_from_slice(&self.random);
// Session ID
message.push(self.session_id.len() as u8);
message.extend_from_slice(&self.session_id);
// Cipher suite
message.extend_from_slice(&self.cipher_suite);
// Compression method
message.push(self.compression);
// Extensions length
message.extend_from_slice(&extensions_len.to_be_bytes());
// Extensions data
message.extend_from_slice(&extensions);
message
}
/// Build complete ServerHello TLS record
fn build_record(&self) -> Vec<u8> {
let message = self.build_message();
let mut record = Vec::with_capacity(5 + message.len());
// TLS record header
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(message.len() as u16).to_be_bytes());
// Message
record.extend_from_slice(&message);
record
}
}
// ============= Public Functions =============
/// Validate TLS ClientHello against user secrets /// Validate TLS ClientHello against user secrets
///
/// Returns validation result if a matching user is found.
pub fn validate_tls_handshake( pub fn validate_tls_handshake(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
@@ -86,7 +292,8 @@ pub fn validate_tls_handshake(
// Check time skew // Check time skew
if !ignore_time_skew { if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time) // Allow very small timestamps (boot time instead of unix time)
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) { if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
continue; continue;
@@ -105,79 +312,73 @@ pub fn validate_tls_handshake(
} }
/// Generate a fake X25519 public key for TLS /// Generate a fake X25519 public key for TLS
/// This generates a value that looks like a valid X25519 key ///
pub fn gen_fake_x25519_key() -> [u8; 32] { /// This generates random bytes that look like a valid X25519 public key.
// For simplicity, just generate random 32 bytes /// Since we're not doing real TLS, the actual cryptographic properties don't matter.
// In real X25519, this would be a point on the curve pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let bytes = SECURE_RANDOM.bytes(32); let bytes = rng.bytes(32);
bytes.try_into().unwrap() bytes.try_into().unwrap()
} }
/// Build TLS ServerHello response /// Build TLS ServerHello response
///
/// This builds a complete TLS 1.3-like response including:
/// - ServerHello record with extensions
/// - Change Cipher Spec record
/// - Fake encrypted certificate (Application Data record)
///
/// The response includes an HMAC digest that the client can verify.
pub fn build_server_hello( pub fn build_server_hello(
secret: &[u8], secret: &[u8],
client_digest: &[u8; TLS_DIGEST_LEN], client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8], session_id: &[u8],
fake_cert_len: usize, fake_cert_len: usize,
rng: &SecureRandom,
) -> Vec<u8> { ) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key(); let x25519_key = gen_fake_x25519_key(rng);
// TLS extensions // Build ServerHello
let mut extensions = Vec::new(); let server_hello = ServerHelloBuilder::new(session_id.to_vec())
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder .with_x25519_key(&x25519_key)
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension .with_tls13_version()
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve .build_record();
extensions.extend_from_slice(&x25519_key);
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
// ServerHello body // Build Change Cipher Spec record
let mut srv_hello = Vec::new(); let change_cipher_spec = [
srv_hello.extend_from_slice(&TLS_VERSION); TLS_RECORD_CHANGE_CIPHER,
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest TLS_VERSION[0], TLS_VERSION[1],
srv_hello.push(session_id.len() as u8); 0x00, 0x01, // length = 1
srv_hello.extend_from_slice(session_id); 0x01, // CCS byte
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 ];
srv_hello.push(0x00); // No compression
srv_hello.extend_from_slice(&extensions);
// Build complete packet // Build fake certificate (Application Data record)
let mut hello_pkt = Vec::new(); let fake_cert = rng.bytes(fake_cert_len);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
app_data_record.extend_from_slice(&(fake_cert_len as u16).to_be_bytes());
app_data_record.extend_from_slice(&fake_cert);
// ServerHello record // Combine all records
hello_pkt.push(TLS_RECORD_HANDSHAKE); let mut response = Vec::with_capacity(
hello_pkt.extend_from_slice(&TLS_VERSION); server_hello.len() + change_cipher_spec.len() + app_data_record.len()
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes()); );
hello_pkt.push(0x02); // ServerHello message type response.extend_from_slice(&server_hello);
let len_bytes = (srv_hello.len() as u32).to_be_bytes(); response.extend_from_slice(&change_cipher_spec);
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length response.extend_from_slice(&app_data_record);
hello_pkt.extend_from_slice(&srv_hello);
// Change Cipher Spec record
hello_pkt.extend_from_slice(&[
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x01, 0x01
]);
// Application Data record (fake certificate)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
hello_pkt.push(TLS_RECORD_APPLICATION);
hello_pkt.extend_from_slice(&TLS_VERSION);
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
hello_pkt.extend_from_slice(&fake_cert);
// Compute HMAC for the response // Compute HMAC for the response
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len()); let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
hmac_input.extend_from_slice(client_digest); hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&hello_pkt); hmac_input.extend_from_slice(&response);
let response_digest = sha256_hmac(secret, &hmac_input); let response_digest = sha256_hmac(secret, &hmac_input);
// Insert computed digest // Insert computed digest into ServerHello
// Position: after record header (5) + message type/length (4) + version (2) = 11 // Position: record header (5) + message type (1) + length (3) + version (2) = 11
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&response_digest); .copy_from_slice(&response_digest);
hello_pkt response
} }
/// Check if bytes look like a TLS ClientHello /// Check if bytes look like a TLS ClientHello
@@ -186,7 +387,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
return false; return false;
} }
// TLS record header: 0x16 0x03 0x01 // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
first_bytes[0] == TLS_RECORD_HANDSHAKE first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03 && first_bytes[1] == 0x03
&& first_bytes[2] == 0x01 && first_bytes[2] == 0x01
@@ -206,6 +407,61 @@ pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
Some((record_type, length)) Some((record_type, length))
} }
/// Validate a ServerHello response structure
///
/// This is useful for testing that our ServerHello is well-formed.
#[cfg(test)]
fn validate_server_hello_structure(data: &[u8]) -> Result<()> {
if data.len() < 5 {
return Err(ProxyError::InvalidTlsRecord {
record_type: 0,
version: [0, 0],
});
}
// Check record header
if data[0] != TLS_RECORD_HANDSHAKE {
return Err(ProxyError::InvalidTlsRecord {
record_type: data[0],
version: [data[1], data[2]],
});
}
// Check version
if data[1..3] != TLS_VERSION {
return Err(ProxyError::InvalidTlsRecord {
record_type: data[0],
version: [data[1], data[2]],
});
}
// Check record length
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
if data.len() < 5 + record_len {
return Err(ProxyError::InvalidHandshake(
format!("ServerHello record truncated: expected {}, got {}",
5 + record_len, data.len())
));
}
// Check message type
if data[5] != 0x02 {
return Err(ProxyError::InvalidHandshake(
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
));
}
// Parse message length
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
if msg_len + 4 != record_len {
return Err(ProxyError::InvalidHandshake(
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
));
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -234,11 +490,155 @@ mod tests {
#[test] #[test]
fn test_gen_fake_x25519_key() { fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key(); let rng = SecureRandom::new();
let key2 = gen_fake_x25519_key(); let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32); assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32); assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random assert_ne!(key1, key2); // Should be random
} }
#[test]
fn test_tls_extension_builder() {
let key = [0x42u8; 32];
let mut builder = TlsExtensionBuilder::new();
builder.add_key_share(&key);
builder.add_supported_versions(0x0304);
let result = builder.build();
// Check length prefix
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
assert_eq!(len, result.len() - 2);
// Check key_share extension is present
assert!(result.len() > 40); // At least key share
}
#[test]
fn test_server_hello_builder() {
let session_id = vec![0x01, 0x02, 0x03, 0x04];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id.clone())
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Validate structure
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
// Check record type
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
// Check version
assert_eq!(&record[1..3], &TLS_VERSION);
// Check message type (ServerHello = 0x02)
assert_eq!(record[5], 0x02);
}
#[test]
fn test_build_server_hello_structure() {
let secret = b"test secret";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
// Should have at least 3 records
assert!(response.len() > 100);
// First record should be ServerHello
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
// Validate ServerHello structure
validate_server_hello_structure(&response).expect("Invalid ServerHello");
// Find Change Cipher Spec
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = server_hello_len;
assert!(response.len() > ccs_start + 6);
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
// Find Application Data
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
}
#[test]
fn test_build_server_hello_digest() {
let secret = b"test secret key here";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let rng = SecureRandom::new();
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
// Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert!(!digest1.iter().all(|&b| b == 0));
// Different calls should have different digests (due to random cert)
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert_ne!(digest1, digest2);
}
#[test]
fn test_server_hello_extensions_length() {
let session_id = vec![0x01; 32];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id)
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Parse to find extensions
let msg_start = 5; // After record header
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
// Skip to session ID
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
let session_id_len = record[session_id_pos] as usize;
// Skip to extensions
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
// Verify extensions length matches actual data
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
assert_eq!(ext_len, extensions_data.len(),
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
}
#[test]
fn test_validate_tls_handshake_format() {
// Build a minimal ClientHello-like structure
let mut handshake = vec![0u8; 100];
// Put a valid-looking digest at position 11
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&[0x42; 32]);
// Session ID length
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
// This won't validate (wrong HMAC) but shouldn't panic
let secrets = vec![("test".to_string(), b"secret".to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true);
// Should return None (no match) but not panic
assert!(result.is_none());
}
} }

View File

@@ -13,103 +13,104 @@ use crate::error::{ProxyError, Result, HandshakeResult};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{ConnectionPool, configure_client_socket}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr; use crate::crypto::{AesCtr, SecureRandom};
use super::handshake::{ use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake, handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
}; };
use super::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use super::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
/// Client connection handler pub struct ClientHandler;
pub struct ClientHandler {
pub struct RunningClientHandler {
stream: TcpStream,
peer: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
} }
impl ClientHandler { impl ClientHandler {
/// Create new client handler
pub fn new( pub fn new(
stream: TcpStream,
peer: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>, buffer_pool: Arc<BufferPool>,
) -> Self { rng: Arc<SecureRandom>,
Self { ) -> RunningClientHandler {
config, RunningClientHandler {
stats, stream, peer, config, stats, replay_checker,
replay_checker, upstream_manager, buffer_pool, rng,
pool,
} }
} }
}
/// Handle a client connection
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) { impl RunningClientHandler {
pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all(); self.stats.increment_connects_all();
let peer = self.peer;
debug!(peer = %peer, "New connection"); debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&stream, &self.stream,
self.config.client_keepalive, self.config.timeouts.client_keepalive,
self.config.client_ack_timeout, self.config.timeouts.client_ack,
) { ) {
debug!(peer = %peer, error = %e, "Failed to configure client socket"); debug!(peer = %peer, error = %e, "Failed to configure client socket");
} }
// Perform handshake with timeout let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); let stats = self.stats.clone();
let result = timeout( let result = timeout(handshake_timeout, self.do_handshake()).await;
handshake_timeout,
self.do_handshake(stream, peer)
).await;
match result { match result {
Ok(Ok(())) => { Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully"); debug!(peer = %peer, "Connection handled successfully");
Ok(())
} }
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed"); debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
} }
Err(_) => { Err(_) => {
self.stats.increment_handshake_timeouts(); stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout"); debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout)
} }
} }
} }
/// Perform handshake and relay async fn do_handshake(mut self) -> Result<()> {
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
// Read first bytes to determine handshake type
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer;
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls { if is_tls {
self.handle_tls_client(stream, peer, first_bytes).await self.handle_tls_client(first_bytes).await
} else { } else {
self.handle_direct_client(stream, peer, first_bytes).await self.handle_direct_client(first_bytes).await
} }
} }
/// Handle TLS-wrapped client async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_tls_client( let peer = self.peer;
&self,
mut stream: TcpStream,
peer: SocketAddr,
first_bytes: [u8; 5],
) -> Result<()> {
// Read TLS handshake length
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
@@ -117,113 +118,111 @@ impl ClientHandler {
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await; let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Split stream for reading/writing let config = self.config.clone();
let (read_half, write_half) = stream.into_split(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let (read_half, write_half) = self.stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, &self.rng,
write_half,
peer,
&self.config,
&self.replay_checker,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Read MTProto handshake through TLS
debug!(peer = %peer, "Reading MTProto handshake through TLS"); debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, &mtproto_handshake, tls_reader, tls_writer, peer,
tls_reader, &config, &replay_checker, true,
tls_writer,
peer,
&self.config,
&self.replay_checker,
true,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader: _, writer: _ } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Handle authenticated client Self::handle_authenticated_static(
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await crypto_reader, crypto_writer, success,
self.upstream_manager, self.stats, self.config,
buffer_pool, self.rng,
).await
} }
/// Handle direct (non-TLS) client async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_direct_client( let peer = self.peer;
&self,
mut stream: TcpStream, if !self.config.general.modes.classic && !self.config.general.modes.secure {
peer: SocketAddr,
first_bytes: [u8; 5],
) -> Result<()> {
// Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await; let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Split stream let config = self.config.clone();
let (read_half, write_half) = stream.into_split(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let (read_half, write_half) = self.stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, false,
write_half,
peer,
&self.config,
&self.replay_checker,
false,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await Self::handle_authenticated_static(
crypto_reader, crypto_writer, success,
self.upstream_manager, self.stats, self.config,
buffer_pool, self.rng,
).await
} }
/// Handle authenticated client - connect to Telegram and relay async fn handle_authenticated_static<R, W>(
async fn handle_authenticated_inner<R, W>(
&self,
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
success: HandshakeSuccess, success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -231,14 +230,12 @@ impl ClientHandler {
{ {
let user = &success.user; let user = &success.user;
// Check user limits if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
if let Err(e) = self.check_user_limits(user) {
warn!(user = %user, error = %e, "User limit exceeded"); warn!(user = %user, error = %e, "User limit exceeded");
return Err(e); return Err(e);
} }
// Get datacenter address let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
let dc_addr = self.get_dc_addr(success.dc_idx)?;
info!( info!(
user = %user, user = %user,
@@ -246,69 +243,54 @@ impl ClientHandler {
dc = success.dc_idx, dc = success.dc_idx,
dc_addr = %dc_addr, dc_addr = %dc_addr,
proto = ?success.proto_tag, proto = ?success.proto_tag,
fast_mode = self.config.fast_mode,
"Connecting to Telegram" "Connecting to Telegram"
); );
// Connect to Telegram // Pass dc_idx for latency-based upstream selection
let tg_stream = self.pool.get(dc_addr).await?; let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
// Perform Telegram handshake and get crypto streams let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
let (tg_reader, tg_writer) = self.do_tg_handshake( tg_stream, &success, &config, rng.as_ref(),
tg_stream,
&success,
).await?; ).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); debug!(peer = %success.peer, "TG handshake complete, starting relay");
// Update stats stats.increment_user_connects(user);
self.stats.increment_user_connects(user); stats.increment_user_curr_connects(user);
self.stats.increment_user_curr_connects(user);
// Relay traffic - передаём Arc::clone(&self.stats)
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
client_reader, client_reader, client_writer,
client_writer, tg_reader, tg_writer,
tg_reader, user, Arc::clone(&stats), buffer_pool,
tg_writer,
user,
Arc::clone(&self.stats),
).await; ).await;
// Update stats stats.decrement_user_curr_connects(user);
self.stats.decrement_user_curr_connects(user);
match &relay_result { match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), Ok(()) => debug!(user = %user, "Relay completed"),
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"), Err(e) => debug!(user = %user, error = %e, "Relay ended with error"),
} }
relay_result relay_result
} }
/// Check user limits (expiration, connection count, data quota) fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
fn check_user_limits(&self, user: &str) -> Result<()> { if let Some(expiration) = config.access.user_expirations.get(user) {
// Check expiration
if let Some(expiration) = self.config.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired { user: user.to_string() });
} }
} }
// Check connection limit if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
if let Some(limit) = self.config.user_max_tcp_conns.get(user) { if stats.get_user_curr_connects(user) >= *limit as u64 {
let current = self.stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
} }
} }
// Check data quota if let Some(quota) = config.access.user_data_quota.get(user) {
if let Some(quota) = self.config.user_data_quota.get(user) { if stats.get_user_total_octets(user) >= *quota {
let used = self.stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
} }
} }
@@ -316,63 +298,105 @@ impl ClientHandler {
Ok(()) Ok(())
} }
/// Get datacenter address by index /// Resolve DC index to a target address.
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> { ///
let idx = (dc_idx.abs() - 1) as usize; /// Matches the C implementation's behavior exactly:
///
let datacenters = if self.config.prefer_ipv6 { /// 1. Look up DC in known clusters (standard DCs ±1..±5)
/// 2. If not found and `force=1` → fall back to `default_cluster`
///
/// In the C code:
/// - `proxy-multi.conf` is downloaded from Telegram, contains only DC ±1..±5
/// - `default 2;` directive sets the default cluster
/// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns default_cluster
/// for any unknown DC (like CDN DC 203)
///
/// So DC 203, DC 101, DC -300, etc. all route to the default DC (2).
/// There is NO modular arithmetic in the C implementation.
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6 &*TG_DATACENTERS_V6
} else { } else {
&*TG_DATACENTERS_V4 &*TG_DATACENTERS_V4
}; };
datacenters.get(idx) let num_dcs = datacenters.len(); // 5
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
.ok_or_else(|| ProxyError::InvalidHandshake( // === Step 1: Check dc_overrides (like C's `proxy_for <dc> <ip>:<port>`) ===
format!("Invalid DC index: {}", dc_idx) let dc_key = dc_idx.to_string();
)) if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
return Ok(addr);
}
Err(_) => {
warn!(dc_idx = dc_idx, addr_str = %addr_str,
"Invalid DC override address in config, ignoring");
}
}
}
// === Step 2: Standard DCs ±1..±5 — direct lookup ===
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
// === Step 3: Unknown DC — fall back to default_cluster ===
// Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)`
// which returns `MC->default_cluster` when the DC is not found.
// Telegram's proxy-multi.conf uses `default 2;`
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1 // DC 2 (index 1) — matches Telegram's `default 2;`
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"Special DC ---> default_cluster"
);
Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT))
} }
/// Perform handshake with Telegram server async fn do_tg_handshake_static(
/// Returns crypto reader and writer for TG connection
async fn do_tg_handshake(
&self,
mut stream: TcpStream, mut stream: TcpStream,
success: &HandshakeSuccess, success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> { ) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag, success.proto_tag,
&success.dec_key, // Client's dec key &success.dec_key,
success.dec_iv, success.dec_iv,
self.config.fast_mode, rng,
config.general.fast_mode,
); );
// Encrypt nonce
let encrypted_nonce = encrypt_tg_nonce(&nonce); let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!( debug!(
peer = %success.peer, peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]), nonce_head = %hex::encode(&nonce[..16]),
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
"Sending nonce to Telegram" "Sending nonce to Telegram"
); );
// Send to Telegram
stream.write_all(&encrypted_nonce).await?; stream.write_all(&encrypted_nonce).await?;
stream.flush().await?; stream.flush().await?;
debug!(peer = %success.peer, "Nonce sent to Telegram");
// Split stream and wrap with crypto
let (read_half, write_half) = stream.into_split(); let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
let tg_reader = CryptoReader::new(read_half, decryptor); Ok((
let tg_writer = CryptoWriter::new(write_half, encryptor); CryptoReader::new(read_half, decryptor),
CryptoWriter::new(write_half, encryptor),
Ok((tg_reader, tg_writer)) ))
} }
} }

View File

@@ -1,11 +1,11 @@
//! MTProto Handshake Magics //! MTProto Handshake
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info}; use tracing::{debug, warn, trace, info};
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::crypto::random::SECURE_RANDOM;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
/// Result of successful handshake /// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HandshakeSuccess { pub struct HandshakeSuccess {
/// Authenticated user name /// Authenticated user name
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
pub is_tls: bool, pub is_tls: bool,
} }
impl Drop for HandshakeSuccess {
fn drop(&mut self) {
self.dec_key.zeroize();
self.dec_iv.zeroize();
self.enc_key.zeroize();
self.enc_iv.zeroize();
}
}
/// Handle fake TLS handshake /// Handle fake TLS handshake
pub async fn handle_tls_handshake<R, W>( pub async fn handle_tls_handshake<R, W>(
handshake: &[u8], handshake: &[u8],
@@ -42,76 +54,74 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)> rng: &SecureRandom,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
{ {
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Extract digest for replay check
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
// Check for replay
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Build secrets list let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
.filter_map(|(name, hex)| { .filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
}) })
.collect(); .collect();
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
// Validate handshake
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
&secrets, &secrets,
config.ignore_time_skew, config.access.ignore_time_skew,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user"); debug!(
return HandshakeResult::BadClient; peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
} }
}; };
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s, Some((_, s)) => s,
None => return HandshakeResult::BadClient, None => return HandshakeResult::BadClient { reader, writer },
}; };
// Build and send response
let response = tls::build_server_hello( let response = tls::build_server_hello(
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
config.fake_cert_len, config.censorship.fake_cert_len,
rng,
); );
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await { if let Err(e) = writer.write_all(&response).await {
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
// Record for replay protection
replay_checker.add_tls_digest(digest_half); replay_checker.add_tls_digest(digest_half);
info!( info!(
@@ -136,39 +146,28 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
is_tls: bool, is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)> ) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
{ {
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
// Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
debug!(
peer = %peer,
dec_prekey_iv = %hex::encode(dec_prekey_iv),
"Extracted prekey+IV from handshake"
);
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret for (user, secret_hex) in &config.access.users {
for (user, secret_hex) in &config.users {
let secret = match hex::decode(secret_hex) { let secret = match hex::decode(secret_hex) {
Ok(s) => s, Ok(s) => s,
Err(_) => continue, Err(_) => continue,
}; };
// Derive decryption key
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@@ -179,38 +178,23 @@ where
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Decrypt handshake to check protocol tag
let mut decryptor = AesCtr::new(&dec_key, dec_iv); let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
trace!(
peer = %peer,
user = %user,
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
"Decrypted handshake tail"
);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => { None => continue,
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
continue;
}
}; };
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
// Check if mode is enabled
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure } if is_tls { config.general.modes.tls } else { config.general.modes.secure }
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };
if !mode_ok { if !mode_ok {
@@ -218,12 +202,10 @@ where
continue; continue;
} }
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
@@ -234,10 +216,8 @@ where
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
// Record for replay protection
replay_checker.add_handshake(dec_prekey_iv); replay_checker.add_handshake(dec_prekey_iv);
// Create new cipher instances
let decryptor = AesCtr::new(&dec_key, dec_iv); let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
@@ -270,56 +250,37 @@ where
} }
debug!(peer = %peer, "MTProto handshake: no matching user found"); debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient HandshakeResult::BadClient { reader, writer }
} }
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection
///
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
/// This means: client's enc_key becomes TG's dec_key and vice versa.
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
client_dec_key: &[u8; 32], client_dec_key: &[u8; 32],
client_dec_iv: u128, client_dec_iv: u128,
rng: &SecureRandom,
fast_mode: bool, fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop { loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
// Check reserved patterns if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
continue;
}
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
continue;
}
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
// In fast mode, we make TG use the same keys as client but swapped:
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
// - What we encrypt TO TG = what we decrypt FROM client
if fast_mode { if fast_mode {
// Put client's dec_key + dec_iv into nonce[8:56]
// This will be used by TG for encryption TO us
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes()); .copy_from_slice(&client_dec_iv.to_be_bytes());
} }
// Now compute what keys WE will use for TG connection
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect(); let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
@@ -329,44 +290,22 @@ pub fn generate_tg_nonce(
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
debug!(
fast_mode = fast_mode,
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
"Generated TG nonce"
);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
} }
} }
/// Encrypt nonce for sending to Telegram /// Encrypt nonce for sending to Telegram
///
/// Only the part from PROTO_TAG_POS onwards is encrypted.
/// The encryption key is derived from enc_key_iv in the nonce itself.
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// enc_key_iv is at nonce[8:56]
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
// Key for encrypting is just the first 32 bytes of enc_key_iv
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv); let mut encryptor = AesCtr::new(&key, iv);
// Encrypt the entire nonce first, then take only the encrypted tail
let encrypted_full = encryptor.encrypt(nonce); let encrypted_full = encryptor.encrypt(nonce);
// Result: unencrypted head + encrypted tail
let mut result = nonce[..PROTO_TAG_POS].to_vec(); let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
trace!(
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
"Encrypted nonce tail"
);
result result
} }
@@ -379,13 +318,12 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = let rng = SecureRandom::new();
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
// Check length
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
// Check proto tag is set
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
} }
@@ -395,17 +333,35 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce); let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN); assert_eq!(encrypted.len(), HANDSHAKE_LEN);
// First PROTO_TAG_POS bytes should be unchanged
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
// Rest should be different (encrypted)
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
} }
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
} }

View File

@@ -1,35 +1,76 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::time::Duration; use std::time::Duration;
use std::str;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::debug; use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum duration for the entire masking relay.
/// Limits resource consumption from slow-loris attacks and port scanners.
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
const MASK_BUFFER_SIZE: usize = 8192; const MASK_BUFFER_SIZE: usize = 8192;
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4 {
if data.starts_with(b"GET ") || data.starts_with(b"POST") ||
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") {
return "HTTP";
}
}
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
return "TLS-scanner";
}
// Check for SSH
if data.starts_with(b"SSH-") {
return "SSH";
}
// Port scanner (very short data)
if data.len() < 10 {
return "port-scanner";
}
"unknown"
}
/// Handle a bad client by forwarding to mask host /// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client( pub async fn handle_bad_client<R, W>(
mut client: TcpStream, mut reader: R,
mut writer: W,
initial_data: &[u8], initial_data: &[u8],
config: &ProxyConfig, config: &ProxyConfig,
) { )
if !config.mask { where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
if !config.censorship.mask {
// Masking disabled, just consume data // Masking disabled, just consume data
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
let mask_host = config.mask_host.as_deref() let client_type = detect_client_type(initial_data);
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port; let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!( debug!(
client_type = client_type,
host = %mask_host, host = %mask_host,
port = mask_port, port = mask_port,
data_len = initial_data.len(),
"Forwarding bad client to mask host" "Forwarding bad client to mask host"
); );
@@ -40,33 +81,32 @@ pub async fn handle_bad_client(
TcpStream::connect(&mask_addr) TcpStream::connect(&mask_addr)
).await; ).await;
let mut mask_stream = match connect_result { let mask_stream = match connect_result {
Ok(Ok(s)) => s, Ok(Ok(s)) => s,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
}; };
let (mut mask_read, mut mask_write) = mask_stream.into_split();
// Send initial data to mask host // Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;
} }
// Relay traffic // Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move { let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop { loop {
match client_read.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await; let _ = mask_write.shutdown().await;
break; break;
@@ -85,11 +125,11 @@ pub async fn handle_bad_client(
loop { loop {
match mask_read.read(&mut buf).await { match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = client_write.shutdown().await; let _ = writer.shutdown().await;
break; break;
} }
Ok(n) => { Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() { if writer.write_all(&buf[..n]).await.is_err() {
break; break;
} }
} }
@@ -105,9 +145,9 @@ pub async fn handle_bad_client(
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) { async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await { while let Ok(n) = reader.read(&mut buf).await {
if n == 0 { if n == 0 {
break; break;
} }

View File

@@ -1,22 +1,320 @@
//! Bidirectional Relay //! Bidirectional Relay — poll-based, no head-of-line blocking
//!
//! ## What changed and why
//!
//! Previous implementation used a single-task `select! { biased; ... }` loop
//! where each branch called `write_all()`. This caused head-of-line blocking:
//! while `write_all()` waited for a slow writer (e.g. client on 3G downloading
//! media), the entire loop was blocked — the other direction couldn't make progress.
//!
//! Symptoms observed in production:
//! - Media loading at ~8 KB/s despite fast server connection
//! - Stop-and-go pattern with 50500ms gaps between chunks
//! - `biased` select starving S→C direction
//! - Some users unable to load media at all
//!
//! ## New architecture
//!
//! Uses `tokio::io::copy_bidirectional` which polls both directions concurrently
//! in a single task via non-blocking `poll_read` / `poll_write` calls:
//!
//! Old (select! + write_all — BLOCKING):
//!
//! loop {
//! select! {
//! biased;
//! data = client.read() => { server.write_all(data).await; } ← BLOCKS here
//! data = server.read() => { client.write_all(data).await; } ← can't run
//! }
//! }
//!
//! New (copy_bidirectional — CONCURRENT):
//!
//! poll(cx) {
//! // Both directions polled in the same poll cycle
//! C→S: poll_read(client) → poll_write(server) // non-blocking
//! S→C: poll_read(server) → poll_write(client) // non-blocking
//! // If one writer is Pending, the other direction still progresses
//! }
//!
//! Benefits:
//! - No head-of-line blocking: slow client download doesn't block uploads
//! - No biased starvation: fair polling of both directions
//! - Proper flush: `copy_bidirectional` calls `poll_flush` when reader stalls,
//! so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
//! - No deadlock risk: old write_all could deadlock when both TCP buffers filled;
//! poll-based approach lets TCP flow control work correctly
//!
//! Stats tracking:
//! - `StatsIo` wraps client side, intercepts `poll_read` / `poll_write`
//! - `poll_read` on client = C→S (client sending) → `octets_from`, `msgs_from`
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
use tokio::time::Instant;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::error::Result; use crate::error::Result;
use crate::stats::Stats; use crate::stats::Stats;
use std::sync::atomic::{AtomicU64, Ordering}; use crate::stream::BufferPool;
const BUFFER_SIZE: usize = 65536; // ============= Constants =============
/// Relay data bidirectionally between client and server /// Activity timeout for iOS compatibility.
///
/// iOS keeps Telegram connections alive in background for up to 30 minutes.
/// Closing earlier causes unnecessary reconnects and handshake overhead.
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
/// Watchdog check interval — also used for periodic rate logging.
///
/// 10 seconds gives responsive timeout detection (±10s accuracy)
/// without measurable overhead from atomic reads.
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
// ============= CombinedStream =============
/// Combines separate read and write halves into a single bidirectional stream.
///
/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
/// but the handshake layer produces split reader/writer pairs
/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
///
/// This wrapper reunifies them with zero overhead — each trait method
/// delegates directly to the corresponding half. No buffering, no copies.
///
/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
/// so there's no aliasing even though both are called on the same `&mut self`.
struct CombinedStream<R, W> {
reader: R,
writer: W,
}
impl<R, W> CombinedStream<R, W> {
fn new(reader: R, writer: W) -> Self {
Self { reader, writer }
}
}
impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
}
}
impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
}
}
// ============= SharedCounters =============
/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
///
/// Using `Relaxed` ordering is sufficient because:
/// - Counters are monotonically increasing (no ABA problem)
/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
/// - No ordering dependencies between different counters
struct SharedCounters {
/// Bytes read from client (C→S direction)
c2s_bytes: AtomicU64,
/// Bytes written to client (S→C direction)
s2c_bytes: AtomicU64,
/// Number of poll_read completions (≈ C→S chunks)
c2s_ops: AtomicU64,
/// Number of poll_write completions (≈ S→C chunks)
s2c_ops: AtomicU64,
/// Milliseconds since relay epoch of last I/O activity
last_activity_ms: AtomicU64,
}
impl SharedCounters {
fn new() -> Self {
Self {
c2s_bytes: AtomicU64::new(0),
s2c_bytes: AtomicU64::new(0),
c2s_ops: AtomicU64::new(0),
s2c_ops: AtomicU64::new(0),
last_activity_ms: AtomicU64::new(0),
}
}
/// Record activity at this instant.
#[inline]
fn touch(&self, now: Instant, epoch: Instant) {
let ms = now.duration_since(epoch).as_millis() as u64;
self.last_activity_ms.store(ms, Ordering::Relaxed);
}
/// How long since last recorded activity.
fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
let now_ms = now.duration_since(epoch).as_millis() as u64;
Duration::from_millis(now_ms.saturating_sub(last_ms))
}
}
// ============= StatsIo =============
/// Transparent I/O wrapper that tracks per-user statistics and activity.
///
/// Wraps the **client** side of the relay. Direction mapping:
///
/// | poll method | direction | stats updated |
/// |-------------|-----------|--------------------------------------|
/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters |
/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters |
///
/// Both update the shared activity timestamp for the watchdog.
///
/// Note on message counts: the original code counted one `read()`/`write_all()`
/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
/// Byte counts are identical; op counts may differ slightly due to different
/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
struct StatsIo<S> {
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
}
impl<S> StatsIo<S> {
fn new(
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
}
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
// C→S: client sent data
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
}
other => other,
}
}
}
impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
}
other => other,
}
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
// ============= Relay =============
/// Relay data bidirectionally between client and server.
///
/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
///
/// ## API compatibility
///
/// Signature is identical to the previous implementation. The `_buffer_pool`
/// parameter is retained for call-site compatibility — `copy_bidirectional`
/// manages its own internal buffers (8 KB per direction).
///
/// ## Guarantees preserved
///
/// - Activity timeout: 30 minutes of inactivity → clean shutdown
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>( pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut client_reader: CR, client_reader: CR,
mut client_writer: CW, client_writer: CW,
mut server_reader: SR, server_reader: SR,
mut server_writer: SW, server_writer: SW,
user: &str, user: &str,
stats: Arc<Stats>, stats: Arc<Stats>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
CR: AsyncRead + Unpin + Send + 'static, CR: AsyncRead + Unpin + Send + 'static,
@@ -24,139 +322,145 @@ where
SR: AsyncRead + Unpin + Send + 'static, SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static, SW: AsyncWrite + Unpin + Send + 'static,
{ {
let user_c2s = user.to_string(); let epoch = Instant::now();
let user_s2c = user.to_string(); let counters = Arc::new(SharedCounters::new());
let user_owned = user.to_string();
// Используем Arc::clone вместо stats.clone()
let stats_c2s = Arc::clone(&stats); // ── Combine split halves into bidirectional streams ──────────────
let stats_s2c = Arc::clone(&stats); let client_combined = CombinedStream::new(client_reader, client_writer);
let mut server = CombinedStream::new(server_reader, server_writer);
let c2s_bytes = Arc::new(AtomicU64::new(0));
let s2c_bytes = Arc::new(AtomicU64::new(0)); // Wrap client with stats/activity tracking
let c2s_bytes_clone = Arc::clone(&c2s_bytes); let mut client = StatsIo::new(
let s2c_bytes_clone = Arc::clone(&s2c_bytes); client_combined,
Arc::clone(&counters),
// Client -> Server task Arc::clone(&stats),
let c2s = tokio::spawn(async move { user_owned.clone(),
let mut buf = vec![0u8; BUFFER_SIZE]; epoch,
let mut total_bytes = 0u64; );
let mut msg_count = 0u64;
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let watchdog = async {
let mut prev_c2s: u64 = 0;
let mut prev_s2c: u64 = 0;
loop { loop {
match client_reader.read(&mut buf).await { tokio::time::sleep(WATCHDOG_INTERVAL).await;
Ok(0) => {
debug!( let now = Instant::now();
user = %user_c2s, let idle = wd_counters.idle_duration(now, epoch);
total_bytes = total_bytes,
msgs = msg_count, // ── Activity timeout ────────────────────────────────────
"Client closed connection (C->S)" if idle >= ACTIVITY_TIMEOUT {
); let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let _ = server_writer.shutdown().await; let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
break; warn!(
} user = %wd_user,
Ok(n) => { c2s_bytes = c2s,
total_bytes += n as u64; s2c_bytes = s2c,
msg_count += 1; idle_secs = idle.as_secs(),
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed); "Activity timeout"
);
stats_c2s.add_user_octets_from(&user_c2s, n as u64); return; // Causes select! to cancel copy_bidirectional
stats_c2s.increment_user_msgs_from(&user_c2s);
trace!(
user = %user_c2s,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data"
);
if let Err(e) = server_writer.write_all(&buf[..n]).await {
debug!(user = %user_c2s, error = %e, "Failed to write to server");
break;
}
if let Err(e) = server_writer.flush().await {
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
break;
}
}
Err(e) => {
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
break;
}
} }
// ── Periodic rate logging ───────────────────────────────
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
let c2s_delta = c2s - prev_c2s;
let s2c_delta = s2c - prev_s2c;
if c2s_delta > 0 || s2c_delta > 0 {
let secs = WATCHDOG_INTERVAL.as_secs_f64();
debug!(
user = %wd_user,
c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
c2s_total = c2s,
s2c_total = s2c,
"Relay active"
);
}
prev_c2s = c2s;
prev_s2c = s2c;
} }
}); };
// Server -> Client task // ── Run bidirectional copy + watchdog concurrently ───────────────
let s2c = tokio::spawn(async move { //
let mut buf = vec![0u8; BUFFER_SIZE]; // copy_bidirectional polls both directions in the same poll() call:
let mut total_bytes = 0u64; // C→S: poll_read(client/StatsIo) → poll_write(server)
let mut msg_count = 0u64; // S→C: poll_read(server) → poll_write(client/StatsIo)
//
loop { // When one direction's writer returns Pending, the other direction
match server_reader.read(&mut buf).await { // continues — no head-of-line blocking.
Ok(0) => { //
debug!( // When the watchdog fires, select! drops the copy future,
user = %user_s2c, // releasing the &mut borrows on client and server.
total_bytes = total_bytes, let copy_result = tokio::select! {
msgs = msg_count, result = copy_bidirectional(&mut client, &mut server) => Some(result),
"Server closed connection (S->C)" _ = watchdog => None, // Activity timeout — cancel relay
); };
let _ = client_writer.shutdown().await;
break; // ── Clean shutdown ──────────────────────────────────────────────
} // After select!, the losing future is dropped, borrows released.
Ok(n) => { // Shut down both write sides for clean TCP FIN.
total_bytes += n as u64; let _ = client.shutdown().await;
msg_count += 1; let _ = server.shutdown().await;
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
// ── Final logging ───────────────────────────────────────────────
stats_s2c.add_user_octets_to(&user_s2c, n as u64); let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
stats_s2c.increment_user_msgs_to(&user_s2c); let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
let duration = epoch.elapsed();
trace!(
user = %user_s2c, match copy_result {
bytes = n, Some(Ok((c2s, s2c))) => {
total = total_bytes, // Normal completion — one side closed the connection
data_preview = %hex::encode(&buf[..n.min(32)]), debug!(
"S->C data" user = %user_owned,
); c2s_bytes = c2s,
s2c_bytes = s2c,
if let Err(e) = client_writer.write_all(&buf[..n]).await { c2s_msgs = c2s_ops,
debug!(user = %user_s2c, error = %e, "Failed to write to client"); s2c_msgs = s2c_ops,
break; duration_secs = duration.as_secs(),
} "Relay finished"
if let Err(e) = client_writer.flush().await { );
debug!(user = %user_s2c, error = %e, "Failed to flush to client"); Ok(())
break;
}
}
Err(e) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break;
}
}
} }
}); Some(Err(e)) => {
// I/O error in one of the directions
// Wait for either direction to complete let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
tokio::select! { let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
result = c2s => { debug!(
if let Err(e) = result { user = %user_owned,
warn!(error = %e, "C->S task panicked"); c2s_bytes = c2s,
} s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
error = %e,
"Relay error"
);
Err(e.into())
} }
result = s2c => { None => {
if let Err(e) = result { // Activity timeout (watchdog fired)
warn!(error = %e, "S->C task panicked"); let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
} let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Relay finished (activity timeout)"
);
Ok(())
} }
} }
debug!(
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
"Relay finished"
);
Ok(())
} }

View File

@@ -1,29 +1,28 @@
//! Statistics //! Statistics and replay protection
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::{Instant, Duration};
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::Mutex;
use lru::LruCache; use lru::LruCache;
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use tracing::debug;
// ============= Stats =============
/// Thread-safe statistics
#[derive(Default)] #[derive(Default)]
pub struct Stats { pub struct Stats {
// Global counters
connects_all: AtomicU64, connects_all: AtomicU64,
connects_bad: AtomicU64, connects_bad: AtomicU64,
handshake_timeouts: AtomicU64, handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>,
// Start time
start_time: RwLock<Option<Instant>>,
} }
/// Per-user statistics
#[derive(Default)] #[derive(Default)]
pub struct UserStats { pub struct UserStats {
pub connects: AtomicU64, pub connects: AtomicU64,
@@ -41,42 +40,20 @@ impl Stats {
stats stats
} }
// Global stats pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
pub fn increment_connects_all(&self) { pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
self.connects_all.fetch_add(1, Ordering::Relaxed); pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
} pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_curr_connects(&self, user: &str) { pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .curr_connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -86,47 +63,33 @@ impl Stats {
} }
pub fn get_user_curr_connects(&self, user: &str) -> u64 { pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed)) .map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0) .unwrap_or(0)
} }
pub fn add_user_octets_from(&self, user: &str, bytes: u64) { pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_from_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn add_user_octets_to(&self, user: &str, bytes: u64) { pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_to_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn increment_user_msgs_from(&self, user: &str) { pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_from_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_msgs_to(&self, user: &str) { pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_to_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn get_user_total_octets(&self, user: &str) -> u64 { pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| { .map(|s| {
s.octets_from_client.load(Ordering::Relaxed) + s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed) s.octets_to_client.load(Ordering::Relaxed)
@@ -141,37 +104,209 @@ impl Stats {
} }
} }
// Arc<Stats> Hightech Stats :D // ============= Replay Checker =============
/// Replay attack checker using LRU cache
pub struct ReplayChecker { pub struct ReplayChecker {
handshakes: RwLock<LruCache<Vec<u8>, ()>>, shards: Vec<Mutex<ReplayShard>>,
tls_digests: RwLock<LruCache<Vec<u8>, ()>>, shard_mask: usize,
window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
} }
impl ReplayChecker { struct ReplayEntry {
pub fn new(capacity: usize) -> Self { seen_at: Instant,
let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); seq: u64,
}
struct ReplayShard {
cache: LruCache<Box<[u8]>, ReplayEntry>,
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
seq_counter: u64,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self { Self {
handshakes: RwLock::new(LruCache::new(cap)), cache: LruCache::new(cap),
tls_digests: RwLock::new(LruCache::new(cap)), queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
} }
} }
pub fn check_handshake(&self, data: &[u8]) -> bool { fn next_seq(&mut self) -> u64 {
self.handshakes.read().contains(&data.to_vec()) self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
// between Borrow<[u8]> and Borrow<Box<[u8]>>
if let Some(entry) = self.cache.peek(key.as_ref()) {
if entry.seq == queue_seq {
self.cache.pop(key.as_ref());
}
}
}
} }
pub fn add_handshake(&self, data: &[u8]) { fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
self.handshakes.write().put(data.to_vec(), ()); self.cleanup(now, window);
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
self.cache.get(key).is_some()
} }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
self.tls_digests.read().contains(&data.to_vec()) self.cleanup(now, window);
let seq = self.next_seq();
let boxed_key: Box<[u8]> = key.into();
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
self.queue.push_back((now, boxed_key, seq));
} }
pub fn add_tls_digest(&self, data: &[u8]) { fn len(&self) -> usize {
self.tls_digests.write().put(data.to_vec(), ()); self.cache.len()
}
}
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
shard_mask: num_shards - 1,
window,
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
}
}
fn get_shard_idx(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check(&self, data: &[u8]) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let found = shard.check(data, Instant::now(), self.window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add(&self, data: &[u8]) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
}
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
ReplayStats {
total_entries,
total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
window_secs: self.window.as_secs(),
}
}
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs(self.window.as_secs().max(1) / 2)
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 { 0.0 }
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
}
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 { 0.0 }
else { self.total_queue_len as f64 / self.total_entries as f64 }
} }
} }
@@ -182,42 +317,60 @@ mod tests {
#[test] #[test]
fn test_stats_shared_counters() { fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
stats.increment_connects_all();
// Симулируем использование из разных "задач" stats.increment_connects_all();
let stats1 = Arc::clone(&stats); stats.increment_connects_all();
let stats2 = Arc::clone(&stats);
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
// Все инкременты должны быть видны
assert_eq!(stats.get_connects_all(), 3); assert_eq!(stats.get_connects_all(), 3);
} }
#[test] #[test]
fn test_user_stats_shared() { fn test_replay_checker_basic() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100, Duration::from_secs(60));
assert!(!checker.check_handshake(b"test1"));
let stats1 = Arc::clone(&stats); checker.add_handshake(b"test1");
let stats2 = Arc::clone(&stats); assert!(checker.check_handshake(b"test1"));
assert!(!checker.check_handshake(b"test2"));
stats1.add_user_octets_from("user1", 100);
stats2.add_user_octets_from("user1", 200);
stats1.add_user_octets_to("user1", 50);
assert_eq!(stats.get_user_total_octets("user1"), 350);
} }
#[test] #[test]
fn test_concurrent_user_connects() { fn test_replay_checker_duplicate_add() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"dup");
stats.increment_user_curr_connects("user1"); checker.add_handshake(b"dup");
stats.increment_user_curr_connects("user1"); assert!(checker.check_handshake(b"dup"));
assert_eq!(stats.get_user_curr_connects("user1"), 2); }
stats.decrement_user_curr_connects("user1"); #[test]
assert_eq!(stats.get_user_curr_connects("user1"), 1); fn test_replay_checker_expiration() {
let checker = ReplayChecker::new(100, Duration::from_millis(50));
checker.add_handshake(b"expire");
assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire"));
}
#[test]
fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"k1");
checker.add_handshake(b"k2");
checker.check_handshake(b"k1");
checker.check_handshake(b"k3");
let stats = checker.stats();
assert_eq!(stats.total_additions, 2);
assert_eq!(stats.total_checks, 2);
assert_eq!(stats.total_hits, 1);
}
#[test]
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check(&i.to_le_bytes()));
}
assert_eq!(checker.stats().total_entries, 500);
} }
} }

451
src/stream/buffer_pool.rs Normal file
View File

@@ -0,0 +1,451 @@
//! Reusable buffer pool to avoid allocations in hot paths
//!
//! This module provides a thread-safe pool of BytesMut buffers
//! that can be reused across connections to reduce allocation pressure.
use bytes::BytesMut;
use crossbeam_queue::ArrayQueue;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
// ============= Configuration =============
/// Default buffer size
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
/// Default maximum number of pooled buffers
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
// ============= Buffer Pool =============
/// Thread-safe pool of reusable buffers
pub struct BufferPool {
/// Queue of available buffers
buffers: ArrayQueue<BytesMut>,
/// Size of each buffer
buffer_size: usize,
/// Maximum number of buffers to pool
max_buffers: usize,
/// Total allocated buffers (including in-use)
allocated: AtomicUsize,
/// Number of times we had to create a new buffer
misses: AtomicUsize,
/// Number of successful reuses
hits: AtomicUsize,
}
impl BufferPool {
/// Create a new buffer pool with default settings
pub fn new() -> Self {
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
}
/// Create a buffer pool with custom configuration
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
Self {
buffers: ArrayQueue::new(max_buffers),
buffer_size,
max_buffers,
allocated: AtomicUsize::new(0),
misses: AtomicUsize::new(0),
hits: AtomicUsize::new(0),
}
}
/// Get a buffer from the pool, or create a new one if empty
pub fn get(self: &Arc<Self>) -> PooledBuffer {
match self.buffers.pop() {
Some(mut buffer) => {
self.hits.fetch_add(1, Ordering::Relaxed);
buffer.clear();
PooledBuffer {
buffer: Some(buffer),
pool: Arc::clone(self),
}
}
None => {
self.misses.fetch_add(1, Ordering::Relaxed);
self.allocated.fetch_add(1, Ordering::Relaxed);
PooledBuffer {
buffer: Some(BytesMut::with_capacity(self.buffer_size)),
pool: Arc::clone(self),
}
}
}
}
/// Try to get a buffer, returns None if pool is empty
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
self.buffers.pop().map(|mut buffer| {
self.hits.fetch_add(1, Ordering::Relaxed);
buffer.clear();
PooledBuffer {
buffer: Some(buffer),
pool: Arc::clone(self),
}
})
}
/// Return a buffer to the pool
fn return_buffer(&self, mut buffer: BytesMut) {
// Clear the buffer but keep capacity
buffer.clear();
// Only return if we haven't exceeded max and buffer is right size
if buffer.capacity() >= self.buffer_size {
// Try to push to pool, if full just drop
let _ = self.buffers.push(buffer);
}
// If buffer was dropped (pool full), decrement allocated
// Actually we don't decrement here because the buffer might have been
// grown beyond our size - we just let it go
}
/// Get pool statistics
pub fn stats(&self) -> PoolStats {
PoolStats {
pooled: self.buffers.len(),
allocated: self.allocated.load(Ordering::Relaxed),
max_buffers: self.max_buffers,
buffer_size: self.buffer_size,
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
}
}
/// Get buffer size
pub fn buffer_size(&self) -> usize {
self.buffer_size
}
/// Preallocate buffers to fill the pool
pub fn preallocate(&self, count: usize) {
let to_alloc = count.min(self.max_buffers);
for _ in 0..to_alloc {
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
break;
}
self.allocated.fetch_add(1, Ordering::Relaxed);
}
}
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
// ============= Pool Statistics =============
/// Statistics about buffer pool usage
#[derive(Debug, Clone)]
pub struct PoolStats {
/// Current number of buffers in pool
pub pooled: usize,
/// Total buffers allocated (in-use + pooled)
pub allocated: usize,
/// Maximum buffers allowed
pub max_buffers: usize,
/// Size of each buffer
pub buffer_size: usize,
/// Number of cache hits (reused buffer)
pub hits: usize,
/// Number of cache misses (new allocation)
pub misses: usize,
}
impl PoolStats {
/// Get hit rate as percentage
pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
(self.hits as f64 / total as f64) * 100.0
}
}
}
// ============= Pooled Buffer =============
/// A buffer that automatically returns to the pool when dropped
pub struct PooledBuffer {
buffer: Option<BytesMut>,
pool: Arc<BufferPool>,
}
impl PooledBuffer {
/// Take the inner buffer, preventing return to pool
pub fn take(mut self) -> BytesMut {
self.buffer.take().unwrap()
}
/// Get the capacity of the buffer
pub fn capacity(&self) -> usize {
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
}
/// Get the length of data in buffer
pub fn len(&self) -> usize {
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
}
/// Clear the buffer
pub fn clear(&mut self) {
if let Some(ref mut b) = self.buffer {
b.clear();
}
}
}
impl Deref for PooledBuffer {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.as_ref().expect("buffer taken")
}
}
impl DerefMut for PooledBuffer {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer.as_mut().expect("buffer taken")
}
}
impl Drop for PooledBuffer {
fn drop(&mut self) {
if let Some(buffer) = self.buffer.take() {
self.pool.return_buffer(buffer);
}
}
}
impl AsRef<[u8]> for PooledBuffer {
fn as_ref(&self) -> &[u8] {
self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[])
}
}
impl AsMut<[u8]> for PooledBuffer {
fn as_mut(&mut self) -> &mut [u8] {
self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut [])
}
}
// ============= Scoped Buffer =============
/// A buffer that can be used for a scoped operation
/// Useful for ensuring buffer is returned even on early return
pub struct ScopedBuffer<'a> {
buffer: &'a mut PooledBuffer,
}
impl<'a> ScopedBuffer<'a> {
/// Create a new scoped buffer
pub fn new(buffer: &'a mut PooledBuffer) -> Self {
buffer.clear();
Self { buffer }
}
}
impl<'a> Deref for ScopedBuffer<'a> {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.deref()
}
}
impl<'a> DerefMut for ScopedBuffer<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer.deref_mut()
}
}
impl<'a> Drop for ScopedBuffer<'a> {
fn drop(&mut self) {
self.buffer.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pool_basic() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get a buffer
let mut buf1 = pool.get();
buf1.extend_from_slice(b"hello");
assert_eq!(&buf1[..], b"hello");
// Drop returns to pool
drop(buf1);
let stats = pool.stats();
assert_eq!(stats.pooled, 1);
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 1);
// Get again - should reuse
let buf2 = pool.get();
assert!(buf2.is_empty()); // Buffer was cleared
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
assert_eq!(stats.hits, 1);
}
#[test]
fn test_pool_multiple_buffers() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get multiple buffers
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.allocated, 3);
assert_eq!(stats.pooled, 0);
// Return all
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 3);
}
#[test]
fn test_pool_overflow() {
let pool = Arc::new(BufferPool::with_config(1024, 2));
// Get 3 buffers (more than max)
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
// Return all - only 2 should be pooled
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 2);
}
#[test]
fn test_pool_take() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
buf.extend_from_slice(b"data");
// Take ownership, buffer should not return to pool
let taken = buf.take();
assert_eq!(&taken[..], b"data");
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
}
#[test]
fn test_pool_preallocate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
pool.preallocate(5);
let stats = pool.stats();
assert_eq!(stats.pooled, 5);
assert_eq!(stats.allocated, 5);
}
#[test]
fn test_pool_try_get() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Pool is empty, try_get returns None
assert!(pool.try_get().is_none());
// Add a buffer to pool
pool.preallocate(1);
// Now try_get should succeed
assert!(pool.try_get().is_some());
assert!(pool.try_get().is_none());
}
#[test]
fn test_hit_rate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// First get is a miss
let buf1 = pool.get();
drop(buf1);
// Second get is a hit
let buf2 = pool.get();
drop(buf2);
// Third get is a hit
let _buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
}
#[test]
fn test_scoped_buffer() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
{
let mut scoped = ScopedBuffer::new(&mut buf);
scoped.extend_from_slice(b"scoped data");
assert_eq!(&scoped[..], b"scoped data");
}
// After scoped is dropped, buffer is cleared
assert!(buf.is_empty());
}
#[test]
fn test_concurrent_access() {
use std::thread;
let pool = Arc::new(BufferPool::with_config(1024, 100));
let mut handles = vec![];
for _ in 0..10 {
let pool_clone = Arc::clone(&pool);
handles.push(thread::spawn(move || {
for _ in 0..100 {
let mut buf = pool_clone.get();
buf.extend_from_slice(b"test");
// buf auto-returned on drop
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let stats = pool.stats();
// All buffers should be returned
assert!(stats.pooled > 0);
}
}

File diff suppressed because it is too large Load Diff

189
src/stream/frame.rs Normal file
View File

@@ -0,0 +1,189 @@
//! MTProto frame types and traits
//!
//! This module defines the common types and traits used by all
//! frame encoding/decoding implementations.
use bytes::{Bytes, BytesMut};
use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
// ============= Frame Types =============
/// A decoded MTProto frame
#[derive(Debug, Clone)]
pub struct Frame {
/// Frame payload data
pub data: Bytes,
/// Frame metadata
pub meta: FrameMeta,
}
impl Frame {
/// Create a new frame with data and default metadata
pub fn new(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta::default(),
}
}
/// Create a new frame with data and metadata
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
Self { data, meta }
}
/// Create an empty frame
pub fn empty() -> Self {
Self::new(Bytes::new())
}
/// Check if frame is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get frame length
pub fn len(&self) -> usize {
self.data.len()
}
/// Create a QuickAck request frame
pub fn quickack(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta {
quickack: true,
..Default::default()
},
}
}
/// Create a simple ACK frame
pub fn simple_ack(data: Bytes) -> Self {
Self {
data,
meta: FrameMeta {
simple_ack: true,
..Default::default()
},
}
}
}
/// Frame metadata
#[derive(Debug, Clone, Default)]
pub struct FrameMeta {
/// Quick ACK requested - client wants immediate acknowledgment
pub quickack: bool,
/// This is a simple ACK message (reversed data)
pub simple_ack: bool,
/// Original padding length (for secure mode)
pub padding_len: u8,
}
impl FrameMeta {
/// Create new empty metadata
pub fn new() -> Self {
Self::default()
}
/// Create with quickack flag
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
/// Create with simple_ack flag
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
}
/// Create with padding length
pub fn with_padding(mut self, len: u8) -> Self {
self.padding_len = len;
self
}
/// Check if any special flags are set
pub fn has_flags(&self) -> bool {
self.quickack || self.simple_ack
}
}
// ============= Codec Trait =============
/// Trait for frame codecs that can encode and decode frames
pub trait FrameCodec: Send + Sync {
/// Get the protocol tag for this codec
fn proto_tag(&self) -> ProtoTag;
/// Encode a frame into the destination buffer
///
/// Returns the number of bytes written.
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;
/// Try to decode a frame from the source buffer
///
/// Returns:
/// - `Ok(Some(frame))` if a complete frame was decoded
/// - `Ok(None)` if more data is needed
/// - `Err(e)` if an error occurred
///
/// On success, the consumed bytes are removed from `src`.
fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;
/// Get the minimum bytes needed to determine frame length
fn min_header_size(&self) -> usize;
/// Get the maximum allowed frame size
fn max_frame_size(&self) -> usize {
// Default: 16MB
16 * 1024 * 1024
}
}
// ============= Codec Factory =============
/// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
match proto_tag {
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_creation() {
let frame = Frame::new(Bytes::from_static(b"test"));
assert_eq!(frame.len(), 4);
assert!(!frame.is_empty());
assert!(!frame.meta.quickack);
let frame = Frame::empty();
assert!(frame.is_empty());
let frame = Frame::quickack(Bytes::from_static(b"ack"));
assert!(frame.meta.quickack);
}
#[test]
fn test_frame_meta() {
let meta = FrameMeta::new()
.with_quickack()
.with_padding(3);
assert!(meta.quickack);
assert!(!meta.simple_ack);
assert_eq!(meta.padding_len, 3);
assert!(meta.has_flags());
}
}

628
src/stream/frame_codec.rs Normal file
View File

@@ -0,0 +1,628 @@
//! tokio-util codec integration for MTProto frames
//!
//! This module provides Encoder/Decoder implementations compatible
//! with tokio-util's Framed wrapper for easy async frame I/O.
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec =============
/// Unified frame codec that wraps all protocol variants
///
/// This codec implements tokio-util's Encoder and Decoder traits,
/// allowing it to be used with `Framed` for async frame I/O.
pub struct FrameCodec {
/// Protocol variant
proto_tag: ProtoTag,
/// Maximum allowed frame size
max_frame_size: usize,
/// RNG for secure padding
rng: Arc<SecureRandom>,
}
impl FrameCodec {
/// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
Self {
proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default
rng,
}
}
/// Set maximum frame size
pub fn with_max_frame_size(mut self, size: usize) -> Self {
self.max_frame_size = size;
self
}
/// Get protocol tag
pub fn proto_tag(&self) -> ProtoTag {
self.proto_tag
}
}
impl Decoder for FrameCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size),
ProtoTag::Secure => decode_secure(src, self.max_frame_size),
}
}
}
impl Encoder<Frame> for FrameCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
}
}
}
// ============= Abridged Protocol =============
fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.is_empty() {
return Ok(None);
}
let mut meta = FrameMeta::new();
let first_byte = src[0];
// Extract length and quickack flag
let mut len_words = (first_byte & 0x7f) as usize;
if first_byte >= 0x80 {
meta.quickack = true;
}
let header_len;
if len_words == 0x7f {
// Extended length (3 more bytes needed)
if src.len() < 4 {
return Ok(None);
}
len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize;
header_len = 4;
} else {
header_len = 1;
}
// Length is in 4-byte words
let byte_len = len_words.checked_mul(4).ok_or_else(|| {
Error::new(ErrorKind::InvalidData, "frame length overflow")
})?;
// Validate size
if byte_len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", byte_len, max_size)
));
}
let total_len = header_len + byte_len;
if src.len() < total_len {
// Reserve space for the rest of the frame
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(header_len);
let data = src.split_to(byte_len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Validate alignment
if data.len() % 4 != 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
));
}
// Simple ACK: send reversed data without header
if frame.meta.simple_ack {
dst.reserve(data.len());
for byte in data.iter().rev() {
dst.put_u8(*byte);
}
return Ok(());
}
let len_words = data.len() / 4;
if len_words < 0x7f {
// Short header
dst.reserve(1 + data.len());
let mut len_byte = len_words as u8;
if frame.meta.quickack {
len_byte |= 0x80;
}
dst.put_u8(len_byte);
} else if len_words < (1 << 24) {
// Extended header
dst.reserve(4 + data.len());
let mut first = 0x7fu8;
if frame.meta.quickack {
first |= 0x80;
}
dst.put_u8(first);
let len_bytes = (len_words as u32).to_le_bytes();
dst.extend_from_slice(&len_bytes[..3]);
} else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", data.len())
));
}
dst.extend_from_slice(data);
Ok(())
}
// ============= Intermediate Protocol =============
fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(4);
let data = src.split_to(len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
dst.reserve(4 + data.len());
let mut len = data.len() as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
Ok(())
}
// ============= Secure Intermediate Protocol =============
fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
// Calculate padding (indicated by length not divisible by 4)
let padding_len = len % 4;
let data_len = if padding_len != 0 {
len - padding_len
} else {
len
};
meta.padding_len = padding_len as u8;
// Extract data (excluding padding)
let _ = src.split_to(4);
let all_data = src.split_to(len);
// Copy only the data portion, excluding padding
let data = Bytes::copy_from_slice(&all_data[..data_len]);
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
// Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned
(rng.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
rng.range(4) as usize
};
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);
let mut len = total_len as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
if padding_len > 0 {
let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding);
}
Ok(())
}
// ============= Typed Codecs =============
/// Abridged protocol codec
pub struct AbridgedCodec {
max_frame_size: usize,
}
impl AbridgedCodec {
pub fn new() -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
}
}
}
impl Default for AbridgedCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for AbridgedCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_abridged(src, self.max_frame_size)
}
}
impl Encoder<Frame> for AbridgedCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_abridged(&frame, dst)
}
}
impl FrameCodecTrait for AbridgedCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Abridged
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_abridged(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_abridged(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
1
}
}
/// Intermediate protocol codec
pub struct IntermediateCodec {
max_frame_size: usize,
}
impl IntermediateCodec {
pub fn new() -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
}
}
}
impl Default for IntermediateCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for IntermediateCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_intermediate(src, self.max_frame_size)
}
}
impl Encoder<Frame> for IntermediateCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_intermediate(&frame, dst)
}
}
impl FrameCodecTrait for IntermediateCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Intermediate
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_intermediate(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_intermediate(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
}
/// Secure Intermediate protocol codec
pub struct SecureCodec {
max_frame_size: usize,
rng: Arc<SecureRandom>,
}
impl SecureCodec {
pub fn new(rng: Arc<SecureRandom>) -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
rng,
}
}
}
impl Default for SecureCodec {
fn default() -> Self {
Self::new(Arc::new(SecureRandom::new()))
}
}
impl Decoder for SecureCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_secure(src, self.max_frame_size)
}
}
impl Encoder<Frame> for SecureCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst, &self.rng)
}
}
impl FrameCodecTrait for SecureCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Secure
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_secure(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
}
// ============= Tests =============
#[cfg(test)]
mod tests {
use super::*;
use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex;
use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use std::sync::Arc;
#[tokio::test]
async fn test_framed_abridged() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
let mut reader = FramedRead::new(server, AbridgedCodec::new());
// Write a frame
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
writer.send(frame).await.unwrap();
// Read it back
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
}
#[tokio::test]
async fn test_framed_intermediate() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::new(Bytes::from_static(b"hello world"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], b"hello world");
}
#[tokio::test]
async fn test_framed_secure() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &original[..]);
}
#[tokio::test]
async fn test_unified_codec() {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
// Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), 8);
}
}
#[tokio::test]
async fn test_multiple_frames() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
// Send multiple frames
for i in 0..10 {
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
let frame = Frame::new(Bytes::from(data));
writer.send(frame).await.unwrap();
}
// Receive them
for i in 0..10 {
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), (i + 1) * 10);
}
}
#[tokio::test]
async fn test_quickack_flag() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert!(received.meta.quickack);
}
#[test]
fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100);
// Create a "frame" that claims to be very large
let mut buf = BytesMut::new();
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
buf.extend_from_slice(&[0u8; 10]); // partial data
let result = codec.decode(&mut buf);
assert!(result.is_err());
}
}

View File

@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result}; use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::crypto::crc32; use crate::crypto::{crc32, SecureRandom};
use crate::crypto::random::SECURE_RANDOM; use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream}; use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame ============= // ============= Abridged (Compact) Frame =============
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
/// Writer for secure intermediate MTProto framing /// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> { pub struct SecureIntermediateFrameWriter<W> {
upstream: W, upstream: W,
rng: Arc<SecureRandom>,
} }
impl<W> SecureIntermediateFrameWriter<W> { impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self { pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
Self { upstream } Self { upstream, rng }
} }
} }
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
} }
// Add random padding (0-3 bytes) // Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4); let padding_len = self.rng.range(4);
let padding = SECURE_RANDOM.bytes(padding_len); let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes(); let len_bytes = (total_len as u32).to_le_bytes();
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
} }
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> { impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag { match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
} }
} }
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
mod tests { mod tests {
use super::*; use super::*;
use tokio::io::duplex; use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
#[tokio::test] #[tokio::test]
async fn test_abridged_roundtrip() { async fn test_abridged_roundtrip() {
@@ -539,7 +542,7 @@ mod tests {
async fn test_secure_intermediate_padding() { async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client); let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server); let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
@@ -572,7 +575,7 @@ mod tests {
async fn test_frame_reader_kind() { async fn test_frame_reader_kind() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4]; let data = vec![1u8, 2, 3, 4];

View File

@@ -1,10 +1,43 @@
//! Stream wrappers for MTProto protocol layers //! Stream wrappers for MTProto protocol layers
pub mod state;
pub mod buffer_pool;
pub mod traits; pub mod traits;
pub mod crypto_stream; pub mod crypto_stream;
pub mod tls_stream; pub mod tls_stream;
pub mod frame;
pub mod frame_codec;
// Legacy compatibility - will be removed later
pub mod frame_stream; pub mod frame_stream;
// Re-export state machine types
pub use state::{
StreamState, Transition, PollResult,
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
};
// Re-export buffer pool
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
// Re-export stream implementations
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream}; pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
pub use tls_stream::{FakeTlsReader, FakeTlsWriter}; pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
pub use frame_stream::*;
// Re-export frame types
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
// Re-export tokio-util compatible codecs
pub use frame_codec::{
FrameCodec,
AbridgedCodec, IntermediateCodec, SecureCodec,
};
// Legacy re-exports for compatibility
pub use frame_stream::{
AbridgedFrameReader, AbridgedFrameWriter,
IntermediateFrameReader, IntermediateFrameWriter,
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
MtprotoFrameReader, MtprotoFrameWriter,
FrameReaderKind, FrameWriterKind,
};

571
src/stream/state.rs Normal file
View File

@@ -0,0 +1,571 @@
//! State machine foundation types for async streams
//!
//! This module provides core types and traits for implementing
//! stateful async streams with proper partial read/write handling.
use bytes::{Bytes, BytesMut};
use std::io;
// ============= Core Traits =============
/// Trait for stream states
pub trait StreamState: Sized {
/// Check if this is a terminal state (no more transitions possible)
fn is_terminal(&self) -> bool;
/// Check if stream is in poisoned/error state
fn is_poisoned(&self) -> bool;
/// Get human-readable state name for debugging
fn state_name(&self) -> &'static str;
}
// ============= Transition Types =============
/// Result of a state transition
#[derive(Debug)]
pub enum Transition<S, O> {
/// Stay in the same state, no output
Same,
/// Transition to a new state, no output
Next(S),
/// Complete with output, typically transitions to Idle
Complete(O),
/// Yield output and transition to new state
Yield(O, S),
/// Error occurred, transition to error state
Error(io::Error),
}
impl<S, O> Transition<S, O> {
/// Check if transition produces output
pub fn has_output(&self) -> bool {
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
}
/// Map the output value
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
match self {
Transition::Same => Transition::Same,
Transition::Next(s) => Transition::Next(s),
Transition::Complete(o) => Transition::Complete(f(o)),
Transition::Yield(o, s) => Transition::Yield(f(o), s),
Transition::Error(e) => Transition::Error(e),
}
}
/// Map the state value
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
match self {
Transition::Same => Transition::Same,
Transition::Next(s) => Transition::Next(f(s)),
Transition::Complete(o) => Transition::Complete(o),
Transition::Yield(o, s) => Transition::Yield(o, f(s)),
Transition::Error(e) => Transition::Error(e),
}
}
}
// ============= Poll Result Types =============
/// Result of polling for more data
#[derive(Debug)]
pub enum PollResult<T> {
/// Data is ready
Ready(T),
/// Operation would block, need to poll again
Pending,
/// Need more input data (minimum bytes required)
NeedInput(usize),
/// End of stream reached
Eof,
/// Error occurred
Error(io::Error),
}
impl<T> PollResult<T> {
/// Check if result is ready
pub fn is_ready(&self) -> bool {
matches!(self, PollResult::Ready(_))
}
/// Check if result indicates EOF
pub fn is_eof(&self) -> bool {
matches!(self, PollResult::Eof)
}
/// Convert to Option, discarding non-ready states
pub fn ok(self) -> Option<T> {
match self {
PollResult::Ready(t) => Some(t),
_ => None,
}
}
/// Map the value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
match self {
PollResult::Ready(t) => PollResult::Ready(f(t)),
PollResult::Pending => PollResult::Pending,
PollResult::NeedInput(n) => PollResult::NeedInput(n),
PollResult::Eof => PollResult::Eof,
PollResult::Error(e) => PollResult::Error(e),
}
}
}
impl<T> From<io::Result<T>> for PollResult<T> {
fn from(result: io::Result<T>) -> Self {
match result {
Ok(t) => PollResult::Ready(t),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof,
Err(e) => PollResult::Error(e),
}
}
}
// ============= Buffer State =============
/// State for buffered reading operations
#[derive(Debug)]
pub struct ReadBuffer {
/// The buffer holding data
buffer: BytesMut,
/// Target number of bytes to read (if known)
target: Option<usize>,
}
impl ReadBuffer {
/// Create new empty read buffer
pub fn new() -> Self {
Self {
buffer: BytesMut::with_capacity(8192),
target: None,
}
}
/// Create with specific capacity
pub fn with_capacity(capacity: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(capacity),
target: None,
}
}
/// Create with target size
pub fn with_target(target: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(target),
target: Some(target),
}
}
/// Get current buffer length
pub fn len(&self) -> usize {
self.buffer.len()
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
/// Check if target is reached
pub fn is_complete(&self) -> bool {
match self.target {
Some(t) => self.buffer.len() >= t,
None => false,
}
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
match self.target {
Some(t) => t.saturating_sub(self.buffer.len()),
None => 0,
}
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) {
self.buffer.extend_from_slice(data);
}
/// Take all data from buffer
pub fn take(&mut self) -> Bytes {
self.target = None;
self.buffer.split().freeze()
}
/// Take exactly n bytes
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
if self.buffer.len() >= n {
Some(self.buffer.split_to(n).freeze())
} else {
None
}
}
/// Get a slice of the buffer
pub fn as_slice(&self) -> &[u8] {
&self.buffer
}
/// Get mutable access to underlying BytesMut
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
&mut self.buffer
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
self.target = None;
}
/// Set new target
pub fn set_target(&mut self, target: usize) {
self.target = Some(target);
}
}
impl Default for ReadBuffer {
fn default() -> Self {
Self::new()
}
}
/// State for buffered writing operations
#[derive(Debug)]
pub struct WriteBuffer {
/// The buffer holding data to write
buffer: BytesMut,
/// Position of next byte to write
position: usize,
/// Maximum buffer size
max_size: usize,
}
impl WriteBuffer {
/// Create new write buffer with default max size (256KB)
pub fn new() -> Self {
Self::with_max_size(256 * 1024)
}
/// Create with specific max size
pub fn with_max_size(max_size: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(8192),
position: 0,
max_size,
}
}
/// Get pending bytes count
pub fn len(&self) -> usize {
self.buffer.len() - self.position
}
/// Check if buffer is empty (all written)
pub fn is_empty(&self) -> bool {
self.position >= self.buffer.len()
}
/// Check if buffer is full
pub fn is_full(&self) -> bool {
self.buffer.len() >= self.max_size
}
/// Get remaining capacity
pub fn remaining_capacity(&self) -> usize {
self.max_size.saturating_sub(self.buffer.len())
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
if self.buffer.len() + data.len() > self.max_size {
return Err(());
}
self.buffer.extend_from_slice(data);
Ok(())
}
/// Get slice of data to write
pub fn pending(&self) -> &[u8] {
&self.buffer[self.position..]
}
/// Advance position by n bytes (after successful write)
pub fn advance(&mut self, n: usize) {
self.position += n;
// If all data written, reset buffer
if self.position >= self.buffer.len() {
self.buffer.clear();
self.position = 0;
}
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
self.position = 0;
}
}
impl Default for WriteBuffer {
fn default() -> Self {
Self::new()
}
}
// ============= Fixed-Size Buffer States =============
/// State for reading a fixed-size header
#[derive(Debug, Clone)]
pub struct HeaderBuffer<const N: usize> {
/// The buffer
data: [u8; N],
/// Bytes filled so far
filled: usize,
}
impl<const N: usize> HeaderBuffer<N> {
/// Create new empty header buffer
pub fn new() -> Self {
Self {
data: [0u8; N],
filled: 0,
}
}
/// Get slice for reading into
pub fn unfilled_mut(&mut self) -> &mut [u8] {
&mut self.data[self.filled..]
}
/// Advance filled count
pub fn advance(&mut self, n: usize) {
self.filled = (self.filled + n).min(N);
}
/// Check if completely filled
pub fn is_complete(&self) -> bool {
self.filled >= N
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
N - self.filled
}
/// Get filled bytes as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[..self.filled]
}
/// Get complete buffer (panics if not complete)
pub fn as_array(&self) -> &[u8; N] {
assert!(self.is_complete());
&self.data
}
/// Take the buffer, resetting state
pub fn take(&mut self) -> [u8; N] {
let data = self.data;
self.data = [0u8; N];
self.filled = 0;
data
}
/// Reset to empty state
pub fn reset(&mut self) {
self.filled = 0;
}
}
impl<const N: usize> Default for HeaderBuffer<N> {
fn default() -> Self {
Self::new()
}
}
// ============= Yield Buffer =============
/// Buffer for yielding data to caller in chunks
#[derive(Debug)]
pub struct YieldBuffer {
data: Bytes,
position: usize,
}
impl YieldBuffer {
/// Create new yield buffer
pub fn new(data: Bytes) -> Self {
Self { data, position: 0 }
}
/// Check if all data has been yielded
pub fn is_empty(&self) -> bool {
self.position >= self.data.len()
}
/// Get remaining bytes
pub fn remaining(&self) -> usize {
self.data.len() - self.position
}
/// Copy data to output slice, return bytes copied
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
let available = &self.data[self.position..];
let to_copy = available.len().min(dst.len());
dst[..to_copy].copy_from_slice(&available[..to_copy]);
self.position += to_copy;
to_copy
}
/// Get remaining data as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[self.position..]
}
}
// ============= Macros =============
/// Macro to simplify state transitions in poll methods
#[macro_export]
macro_rules! transition {
(same) => {
$crate::stream::state::Transition::Same
};
(next $state:expr) => {
$crate::stream::state::Transition::Next($state)
};
(complete $output:expr) => {
$crate::stream::state::Transition::Complete($output)
};
(yield $output:expr, $state:expr) => {
$crate::stream::state::Transition::Yield($output, $state)
};
(error $err:expr) => {
$crate::stream::state::Transition::Error($err)
};
}
/// Macro to match poll ready or return pending
#[macro_export]
macro_rules! ready_or_pending {
($poll:expr) => {
match $poll {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => return std::task::Poll::Pending,
}
};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_read_buffer_basic() {
let mut buf = ReadBuffer::with_target(10);
assert_eq!(buf.remaining(), 10);
assert!(!buf.is_complete());
buf.extend(b"hello");
assert_eq!(buf.len(), 5);
assert_eq!(buf.remaining(), 5);
assert!(!buf.is_complete());
buf.extend(b"world");
assert_eq!(buf.len(), 10);
assert!(buf.is_complete());
}
#[test]
fn test_read_buffer_take() {
let mut buf = ReadBuffer::new();
buf.extend(b"test data");
let data = buf.take();
assert_eq!(&data[..], b"test data");
assert!(buf.is_empty());
}
#[test]
fn test_write_buffer_basic() {
let mut buf = WriteBuffer::with_max_size(100);
assert!(buf.is_empty());
buf.extend(b"hello").unwrap();
assert_eq!(buf.len(), 5);
assert!(!buf.is_empty());
buf.advance(3);
assert_eq!(buf.len(), 2);
assert_eq!(buf.pending(), b"lo");
}
#[test]
fn test_write_buffer_overflow() {
let mut buf = WriteBuffer::with_max_size(10);
assert!(buf.extend(b"short").is_ok());
assert!(buf.extend(b"toolong").is_err());
}
#[test]
fn test_header_buffer() {
let mut buf = HeaderBuffer::<5>::new();
assert!(!buf.is_complete());
assert_eq!(buf.remaining(), 5);
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
buf.advance(3);
assert_eq!(buf.remaining(), 2);
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
buf.advance(2);
assert!(buf.is_complete());
assert_eq!(buf.as_array(), b"hello");
}
#[test]
fn test_yield_buffer() {
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
let mut dst = [0u8; 5];
assert_eq!(buf.copy_to(&mut dst), 5);
assert_eq!(&dst, b"hello");
assert_eq!(buf.remaining(), 6);
let mut dst = [0u8; 10];
assert_eq!(buf.copy_to(&mut dst), 6);
assert_eq!(&dst[..6], b" world");
assert!(buf.is_empty());
}
#[test]
fn test_transition_map() {
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
let t = t.map_output(|s| s.len());
match t {
Transition::Complete(5) => {}
_ => panic!("Expected Complete(5)"),
}
}
#[test]
fn test_poll_result() {
let r: PollResult<i32> = PollResult::Ready(42);
assert!(r.is_ready());
assert_eq!(r.ok(), Some(42));
let r: PollResult<i32> = PollResult::Eof;
assert!(r.is_eof());
assert_eq!(r.ok(), None);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,11 @@
pub mod pool; pub mod pool;
pub mod proxy_protocol; pub mod proxy_protocol;
pub mod socket; pub mod socket;
pub mod socks;
pub mod upstream;
pub use pool::ConnectionPool; pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*; pub use socket::*;
pub use socks::*;
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};

View File

@@ -1,7 +1,7 @@
//! TCP Socket Configuration //! TCP Socket Configuration
use std::io::Result; use std::io::Result;
use std::net::SocketAddr; use std::net::{SocketAddr, IpAddr};
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol}; use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
} }
// Set buffer sizes // CHANGED: Removed manual buffer size setting (was 256KB).
set_buffer_sizes(&socket, 65536, 65536)?; // Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
Ok(()) Ok(())
} }
/// Set socket buffer sizes
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
// These may fail on some systems, so we ignore errors
let _ = socket.set_recv_buffer_size(recv);
let _ = socket.set_send_buffer_size(send);
Ok(())
}
/// Configure socket for accepting client connections /// Configure socket for accepting client connections
pub fn configure_client_socket( pub fn configure_client_socket(
stream: &TcpStream, stream: &TcpStream,
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
// Set TCP user timeout (Linux only) // Set TCP user timeout (Linux only)
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
// is implemented in relay_bidirectional instead
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;
@@ -93,6 +88,11 @@ pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
/// Create a new TCP socket for outgoing connections /// Create a new TCP socket for outgoing connections
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> { pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
create_outgoing_socket_bound(addr, None)
}
/// Create a new TCP socket for outgoing connections, optionally bound to a specific interface
pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option<IpAddr>) -> Result<Socket> {
let domain = if addr.is_ipv4() { let domain = if addr.is_ipv4() {
Domain::IPV4 Domain::IPV4
} else { } else {
@@ -106,10 +106,17 @@ pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
// Disable Nagle // Disable Nagle
socket.set_nodelay(true)?; socket.set_nodelay(true)?;
if let Some(bind_ip) = bind_addr {
let bind_sock_addr = SocketAddr::new(bind_ip, 0);
socket.bind(&bind_sock_addr.into())?;
debug!("Bound outgoing socket to {}", bind_ip);
}
Ok(socket) Ok(socket)
} }
/// Get local address of a socket /// Get local address of a socket
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> { pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
stream.local_addr().ok() stream.local_addr().ok()

145
src/transport/socks.rs Normal file
View File

@@ -0,0 +1,145 @@
//! SOCKS4/5 Client Implementation
use std::net::{IpAddr, SocketAddr};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::error::{ProxyError, Result};
pub async fn connect_socks4(
stream: &mut TcpStream,
target: SocketAddr,
user_id: Option<&str>,
) -> Result<()> {
let ip = match target.ip() {
IpAddr::V4(ip) => ip,
IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())),
};
let port = target.port();
let user = user_id.unwrap_or("").as_bytes();
// VN (4) | CD (1) | DSTPORT (2) | DSTIP (4) | USERID (variable) | NULL (1)
let mut buf = Vec::with_capacity(9 + user.len());
buf.push(4); // VN
buf.push(1); // CD (CONNECT)
buf.extend_from_slice(&port.to_be_bytes());
buf.extend_from_slice(&ip.octets());
buf.extend_from_slice(user);
buf.push(0); // NULL
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
// Response: VN (1) | CD (1) | DSTPORT (2) | DSTIP (4)
let mut resp = [0u8; 8];
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
if resp[1] != 90 {
return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1])));
}
Ok(())
}
pub async fn connect_socks5(
stream: &mut TcpStream,
target: SocketAddr,
username: Option<&str>,
password: Option<&str>,
) -> Result<()> {
// 1. Auth negotiation
// VER (1) | NMETHODS (1) | METHODS (variable)
let mut methods = vec![0u8]; // No auth
if username.is_some() {
methods.push(2u8); // Username/Password
}
let mut buf = vec![5u8, methods.len() as u8];
buf.extend_from_slice(&methods);
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
let mut resp = [0u8; 2];
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
if resp[0] != 5 {
return Err(ProxyError::Proxy("Invalid SOCKS5 version".to_string()));
}
match resp[1] {
0 => {}, // No auth
2 => {
// Username/Password auth
if let (Some(u), Some(p)) = (username, password) {
let u_bytes = u.as_bytes();
let p_bytes = p.as_bytes();
let mut auth_buf = Vec::with_capacity(3 + u_bytes.len() + p_bytes.len());
auth_buf.push(1); // VER
auth_buf.push(u_bytes.len() as u8);
auth_buf.extend_from_slice(u_bytes);
auth_buf.push(p_bytes.len() as u8);
auth_buf.extend_from_slice(p_bytes);
stream.write_all(&auth_buf).await.map_err(|e| ProxyError::Io(e))?;
let mut auth_resp = [0u8; 2];
stream.read_exact(&mut auth_resp).await.map_err(|e| ProxyError::Io(e))?;
if auth_resp[1] != 0 {
return Err(ProxyError::Proxy("SOCKS5 authentication failed".to_string()));
}
} else {
return Err(ProxyError::Proxy("SOCKS5 server requires authentication".to_string()));
}
},
_ => return Err(ProxyError::Proxy("Unsupported SOCKS5 auth method".to_string())),
}
// 2. Connection request
// VER (1) | CMD (1) | RSV (1) | ATYP (1) | DST.ADDR (variable) | DST.PORT (2)
let mut req = vec![5u8, 1u8, 0u8]; // CONNECT
match target {
SocketAddr::V4(v4) => {
req.push(1u8); // IPv4
req.extend_from_slice(&v4.ip().octets());
},
SocketAddr::V6(v6) => {
req.push(4u8); // IPv6
req.extend_from_slice(&v6.ip().octets());
},
}
req.extend_from_slice(&target.port().to_be_bytes());
stream.write_all(&req).await.map_err(|e| ProxyError::Io(e))?;
// Response
let mut head = [0u8; 4];
stream.read_exact(&mut head).await.map_err(|e| ProxyError::Io(e))?;
if head[1] != 0 {
return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1])));
}
// Skip address part of response
match head[3] {
1 => { // IPv4
let mut addr = [0u8; 4 + 2];
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
},
3 => { // Domain
let mut len = [0u8; 1];
stream.read_exact(&mut len).await.map_err(|e| ProxyError::Io(e))?;
let mut addr = vec![0u8; len[0] as usize + 2];
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
},
4 => { // IPv6
let mut addr = [0u8; 16 + 2];
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
},
_ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())),
}
Ok(())
}

488
src/transport/upstream.rs Normal file
View File

@@ -0,0 +1,488 @@
//! Upstream Management with per-DC latency-weighted selection
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time::Instant;
use rand::Rng;
use tracing::{debug, warn, info, trace};
use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::create_outgoing_socket_bound;
use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters
const NUM_DCS: usize = 5;
// ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)]
struct LatencyEma {
value_ms: Option<f64>,
alpha: f64,
}
impl LatencyEma {
const fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
// ============= Upstream State =============
#[derive(Debug)]
struct UpstreamState {
config: UpstreamConfig,
healthy: bool,
fails: u32,
last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS],
}
impl UpstreamState {
fn new(config: UpstreamConfig) -> Self {
Self {
config,
healthy: true,
fails: 0,
last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
}
}
/// Map DC index to latency array slot (0..NUM_DCS).
///
/// Matches the C implementation's `mf_cluster_lookup` behavior:
/// - Standard DCs ±1..±5 → direct mapping to array index 0..4
/// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2)
/// This matches Telegram's `default 2;` in proxy-multi.conf.
/// - There is NO modular arithmetic in the C implementation.
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc == 0 {
return None;
}
if abs_dc >= 1 && abs_dc <= NUM_DCS {
Some(abs_dc - 1)
} else {
// Unknown DC → default cluster (DC 2, index 1)
// Same as C: mf_cluster_lookup returns default_cluster
Some(1)
}
}
/// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
// Try DC-specific latency first
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() {
return Some(ms);
}
}
// Fallback: average of all known DC latencies
let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
if count > 0 { Some(sum / count as f64) } else { None }
}
}
/// Result of a single DC ping
#[derive(Debug, Clone)]
pub struct DcPingResult {
pub dc_idx: usize,
pub dc_addr: SocketAddr,
pub rtt_ms: Option<f64>,
pub error: Option<String>,
}
/// Result of startup ping for one upstream
#[derive(Debug, Clone)]
pub struct StartupPingResult {
pub results: Vec<DcPingResult>,
pub upstream_name: String,
}
// ============= Upstream Manager =============
#[derive(Clone)]
pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
}
impl UpstreamManager {
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
let states = configs.into_iter()
.filter(|c| c.enabled)
.map(UpstreamState::new)
.collect();
Self {
upstreams: Arc::new(RwLock::new(states)),
}
}
/// Select upstream using latency-weighted random selection.
///
/// `effective_weight = config_weight × latency_factor`
///
/// where `latency_factor = 1000 / latency_ms` if latency is known,
/// or `1.0` if no latency data is available.
///
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
/// factor 5 — the faster route is 4× more likely to be chosen
/// (all else being equal).
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await;
if upstreams.is_empty() {
return None;
}
let healthy: Vec<usize> = upstreams.iter()
.enumerate()
.filter(|(_, u)| u.healthy)
.map(|(i, _)| i)
.collect();
if healthy.is_empty() {
// All unhealthy — pick any
return Some(rand::rng().gen_range(0..upstreams.len()));
}
if healthy.len() == 1 {
return Some(healthy[0]);
}
// Calculate latency-weighted scores
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx)
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
.unwrap_or(1.0);
(i, base * latency_factor)
}).collect();
let total: f64 = weights.iter().map(|(_, w)| w).sum();
if total <= 0.0 {
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
}
let mut choice: f64 = rand::rng().gen_range(0.0..total);
for &(idx, weight) in &weights {
if choice < weight {
trace!(
upstream = idx,
dc = ?dc_idx,
weight = format!("{:.2}", weight),
total = format!("{:.2}", total),
"Upstream selected"
);
return Some(idx);
}
choice -= weight;
}
Some(healthy[0])
}
/// Connect to target through a selected upstream.
///
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
/// Pass `None` if DC index is unknown.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
let upstream = {
let guard = self.upstreams.read().await;
guard[idx].config.clone()
};
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
if !u.healthy {
debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
}
u.healthy = true;
u.fails = 0;
// Store per-DC latency
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms);
}
}
Ok(stream)
},
Err(e) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
u.fails += 1;
warn!(fails = u.fails, "Upstream failed: {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream marked unhealthy");
}
}
Err(e)
}
}
}
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
match &config.upstream_type {
UpstreamType::Direct { interface } => {
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(target, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&target.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
Ok(stream)
},
UpstreamType::Socks4 { address, interface, user_id } => {
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
Ok(stream)
},
UpstreamType::Socks5 { address, interface, username, password } => {
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
Ok(stream)
},
}
}
// ============= Startup Ping =============
/// Ping all Telegram DCs through all upstreams.
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
guard.iter().enumerate()
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
}
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut dc_results = Vec::new();
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout(
Duration::from_secs(5),
self.ping_single_dc(upstream_config, dc_addr)
).await;
let result = match ping_result {
Ok(Ok(rtt_ms)) => {
// Store per-DC latency
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some("timeout (5s)".to_string()),
},
};
dc_results.push(result);
}
all_results.push(StartupPingResult {
results: dc_results,
upstream_name,
});
}
all_results
}
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0)
}
// ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval.
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut dc_rotation = 0usize;
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % datacenters.len();
dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT);
let count = self.upstreams.read().await.len();
for i in 0..count {
let config = {
let guard = self.upstreams.read().await;
guard[i].config.clone()
};
let start = Instant::now();
let result = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, check_target)
).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result {
Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
info!(
rtt = format!("{:.0}ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered"
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed: {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
}
Err(_) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
}
}
u.last_check = std::time::Instant::now();
}
}
}
}

View File

@@ -1,6 +1,6 @@
//! IP Addr Detect //! IP Addr Detect
use std::net::IpAddr; use std::net::{IpAddr, SocketAddr, UdpSocket};
use std::time::Duration; use std::time::Duration;
use tracing::{debug, warn}; use tracing::{debug, warn};
@@ -40,28 +40,74 @@ const IPV6_URLS: &[&str] = &[
"http://api6.ipify.org/", "http://api6.ipify.org/",
]; ];
/// Detect local IP address by connecting to a public DNS
/// This does not actually send any packets
fn get_local_ip(target: &str) -> Option<IpAddr> {
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
socket.connect(target).ok()?;
socket.local_addr().ok().map(|addr| addr.ip())
}
fn get_local_ipv6(target: &str) -> Option<IpAddr> {
let socket = UdpSocket::bind("[::]:0").ok()?;
socket.connect(target).ok()?;
socket.local_addr().ok().map(|addr| addr.ip())
}
/// Detect public IP addresses /// Detect public IP addresses
pub async fn detect_ip() -> IpInfo { pub async fn detect_ip() -> IpInfo {
let mut info = IpInfo::default(); let mut info = IpInfo::default();
// Try to get local interface IP first (default gateway interface)
// We connect to Google DNS to find out which interface is used for routing
if let Some(ip) = get_local_ip("8.8.8.8:80") {
if ip.is_ipv4() && !ip.is_loopback() {
info.ipv4 = Some(ip);
debug!(ip = %ip, "Detected local IPv4 address via routing");
}
}
if let Some(ip) = get_local_ipv6("[2001:4860:4860::8888]:80") {
if ip.is_ipv6() && !ip.is_loopback() {
info.ipv6 = Some(ip);
debug!(ip = %ip, "Detected local IPv6 address via routing");
}
}
// Detect IPv4 // If local detection failed or returned private IP (and we want public),
for url in IPV4_URLS { // or just as a fallback/verification, we might want to check external services.
if let Some(ip) = fetch_ip(url).await { // However, the requirement is: "if IP for listening is not set... it should be IP from interface...
if ip.is_ipv4() { // if impossible - request external resources".
info.ipv4 = Some(ip);
debug!(ip = %ip, "Detected IPv4 address"); // So if we found a local IP, we might be good. But often servers are behind NAT.
break; // If the local IP is private, we probably want the public IP for the tg:// link.
// Let's check if the detected IPs are private.
let need_external_v4 = info.ipv4.map_or(true, |ip| is_private_ip(ip));
let need_external_v6 = info.ipv6.map_or(true, |ip| is_private_ip(ip));
if need_external_v4 {
debug!("Local IPv4 is private or missing, checking external services...");
for url in IPV4_URLS {
if let Some(ip) = fetch_ip(url).await {
if ip.is_ipv4() {
info.ipv4 = Some(ip);
debug!(ip = %ip, "Detected public IPv4 address");
break;
}
} }
} }
} }
// Detect IPv6 if need_external_v6 {
for url in IPV6_URLS { debug!("Local IPv6 is private or missing, checking external services...");
if let Some(ip) = fetch_ip(url).await { for url in IPV6_URLS {
if ip.is_ipv6() { if let Some(ip) = fetch_ip(url).await {
info.ipv6 = Some(ip); if ip.is_ipv6() {
debug!(ip = %ip, "Detected IPv6 address"); info.ipv6 = Some(ip);
break; debug!(ip = %ip, "Detected public IPv6 address");
break;
}
} }
} }
} }
@@ -73,6 +119,17 @@ pub async fn detect_ip() -> IpInfo {
info info
} }
fn is_private_ip(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(ipv4) => {
ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local()
}
IpAddr::V6(ipv6) => {
ipv6.is_loopback() || (ipv6.segments()[0] & 0xfe00) == 0xfc00 // Unique Local
}
}
}
/// Fetch IP from URL /// Fetch IP from URL
async fn fetch_ip(url: &str) -> Option<IpAddr> { async fn fetch_ip(url: &str) -> Option<IpAddr> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()

12
telemt.service Normal file
View File

@@ -0,0 +1,12 @@
[Unit]
Description=Telemt
After=network.target
[Service]
Type=simple
WorkingDirectory=/bin
ExecStart=/bin/telemt /etc/telemt.toml
Restart=on-failure
[Install]
WantedBy=multi-user.target