Compare commits
99 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a688bfe22f | ||
|
|
9bd12f6acb | ||
|
|
61581203c4 | ||
|
|
84668e671e | ||
|
|
5bde202866 | ||
|
|
9304d5256a | ||
|
|
364bc6e278 | ||
|
|
e83db704b7 | ||
|
|
acf90043eb | ||
|
|
0011e20653 | ||
|
|
41fb307858 | ||
|
|
6a78c44d2e | ||
|
|
be9c9858ac | ||
|
|
2fa8d85b4c | ||
|
|
310666fd44 | ||
|
|
6cafee153a | ||
|
|
32f60f34db | ||
|
|
158eae8d2a | ||
|
|
92cedabc81 | ||
|
|
b9428d9780 | ||
|
|
5876f0c4d5 | ||
|
|
94750a2749 | ||
|
|
cf4b240913 | ||
|
|
1424fbb1d5 | ||
|
|
97f4c0d3b7 | ||
|
|
806536fab6 | ||
|
|
df8cfe462b | ||
|
|
a5f1521d71 | ||
|
|
8de7b7adc0 | ||
|
|
cde1b15ef0 | ||
|
|
46e4c06ba6 | ||
|
|
b7673daf0f | ||
|
|
397ed8f193 | ||
|
|
d90b2fd300 | ||
|
|
d62136d9fa | ||
|
|
0f8933b908 | ||
|
|
0ec87974d1 | ||
|
|
c8446c32d1 | ||
|
|
f79a2eb097 | ||
|
|
dea1a3b5de | ||
|
|
97ce235ae4 | ||
|
|
d04757eb9c | ||
|
|
2d7901a978 | ||
|
|
3881ba9bed | ||
|
|
5ac9089ccb | ||
|
|
eb8b991818 | ||
|
|
2ce8fbb2cc | ||
|
|
038f0cd5d1 | ||
|
|
efea3f981d | ||
|
|
42ce9dd671 | ||
|
|
4fa6867056 | ||
|
|
54ea6efdd0 | ||
|
|
27ac32a901 | ||
|
|
829f53c123 | ||
|
|
43eae6127d | ||
|
|
a03212c8cc | ||
|
|
2613969a7c | ||
|
|
be1b2db867 | ||
|
|
8fbee8701b | ||
|
|
952d160870 | ||
|
|
91ae6becde | ||
|
|
e1f576e4fe | ||
|
|
a7556cabdc | ||
|
|
b2e8d16bb1 | ||
|
|
d95e762812 | ||
|
|
384f927fc3 | ||
|
|
1b7c09ae18 | ||
|
|
85cb4092d5 | ||
|
|
5016160ac3 | ||
|
|
4f007f3128 | ||
|
|
7746a1177c | ||
|
|
2bb2a2983f | ||
|
|
5778be4f6e | ||
|
|
f443d3dfc7 | ||
|
|
450cf180ad | ||
|
|
84fa7face0 | ||
|
|
f8a2ea1972 | ||
|
|
96d0a6bdfa | ||
|
|
eeee55e8ea | ||
|
|
7be179b3c0 | ||
|
|
b2e034f8f1 | ||
|
|
ffe5a6cfb7 | ||
|
|
0e096ca8fb | ||
|
|
50658525cf | ||
|
|
4fd5ff4e83 | ||
|
|
df4f312fec | ||
|
|
7d9a8b99b4 | ||
|
|
06f34e55cd | ||
|
|
153cb7f3a3 | ||
|
|
7f8904a989 | ||
|
|
0ee71a59a0 | ||
|
|
45c7347e22 | ||
|
|
3805237d74 | ||
|
|
5b281bf7fd | ||
|
|
d64cccd52c | ||
|
|
016fdada68 | ||
|
|
2c2ceeaf54 | ||
|
|
dd6badd786 | ||
|
|
50e72368c8 |
46
.github/workflows/rust.yml
vendored
Normal file
46
.github/workflows/rust.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
name: Rust
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_TERM_COLOR: always
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
actions: write
|
||||||
|
checks: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install latest stable Rust toolchain
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
components: rustfmt, clippy
|
||||||
|
|
||||||
|
- name: Cache cargo registry & build artifacts
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/registry
|
||||||
|
~/.cargo/git
|
||||||
|
target
|
||||||
|
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-cargo-
|
||||||
|
|
||||||
|
- name: Build Release
|
||||||
|
run: cargo build --release --verbose
|
||||||
|
|
||||||
|
- name: Check for unused dependencies
|
||||||
|
run: cargo udeps || true
|
||||||
2742
Cargo.lock
generated
Normal file
2742
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -1,15 +1,14 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "1.0.0"
|
version = "1.2.0"
|
||||||
edition = "2021"
|
edition = "2024"
|
||||||
rust-version = "1.75"
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# C
|
# C
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
|
|
||||||
# Async runtime
|
# Async runtime
|
||||||
tokio = { version = "1.35", features = ["full", "tracing"] }
|
tokio = { version = "1.42", features = ["full", "tracing"] }
|
||||||
tokio-util = { version = "0.7", features = ["codec"] }
|
tokio-util = { version = "0.7", features = ["codec"] }
|
||||||
|
|
||||||
# Crypto
|
# Crypto
|
||||||
@@ -20,40 +19,41 @@ sha2 = "0.10"
|
|||||||
sha1 = "0.10"
|
sha1 = "0.10"
|
||||||
md-5 = "0.10"
|
md-5 = "0.10"
|
||||||
hmac = "0.12"
|
hmac = "0.12"
|
||||||
crc32fast = "1.3"
|
crc32fast = "1.4"
|
||||||
|
zeroize = { version = "1.8", features = ["derive"] }
|
||||||
|
|
||||||
# Network
|
# Network
|
||||||
socket2 = { version = "0.5", features = ["all"] }
|
socket2 = { version = "0.5", features = ["all"] }
|
||||||
rustls = "0.22"
|
|
||||||
|
|
||||||
# Serial
|
# Serialization
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
bytes = "1.5"
|
bytes = "1.9"
|
||||||
thiserror = "1.0"
|
thiserror = "2.0"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
dashmap = "5.5"
|
dashmap = "5.5"
|
||||||
lru = "0.12"
|
lru = "0.12"
|
||||||
rand = "0.8"
|
rand = "0.9"
|
||||||
chrono = { version = "0.4", features = ["serde"] }
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
base64 = "0.21"
|
base64 = "0.22"
|
||||||
url = "2.5"
|
url = "2.5"
|
||||||
regex = "1.10"
|
regex = "1.11"
|
||||||
once_cell = "1.19"
|
crossbeam-queue = "0.3"
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
|
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio-test = "0.4"
|
tokio-test = "0.4"
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
proptest = "1.4"
|
proptest = "1.4"
|
||||||
|
futures = "0.3"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "crypto_bench"
|
name = "crypto_bench"
|
||||||
|
|||||||
425
README.md
425
README.md
@@ -1,2 +1,423 @@
|
|||||||
# telemt
|
# Telemt - MTProxy on Rust + Tokio
|
||||||
MTProxy for Telegram on Rust + Tokio
|
|
||||||
|
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
|
||||||
|
|
||||||
|
## Emergency
|
||||||
|
**Важное сообщение для пользователей из России**
|
||||||
|
|
||||||
|
Мы работаем над проектом с Нового года и сейчас готовим новый релиз - 1.2
|
||||||
|
|
||||||
|
В нём имплементируется поддержка Middle Proxy Protocol - основного терминатора для Ad Tag:
|
||||||
|
работа над ним идёт с 6 ферваля, а уже 10 февраля произошли "громкие события"...
|
||||||
|
|
||||||
|
Если у вас есть компетенции в асинхронных сетевых приложениях - мы открыты к предложениям и pull requests
|
||||||
|
|
||||||
|
**Important message for users from Russia**
|
||||||
|
|
||||||
|
We've been working on the project since December 30 and are currently preparing a new release – 1.2
|
||||||
|
|
||||||
|
It implements support for the Middle Proxy Protocol – the primary point for the Ad Tag:
|
||||||
|
development on it started on February 6th, and by February 10th, "big activity" in Russia had already "taken place"...
|
||||||
|
|
||||||
|
If you have expertise in asynchronous network applications – we are open to ideas and pull requests!
|
||||||
|
|
||||||
|
# Features
|
||||||
|
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
|
||||||
|
|
||||||
|
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
|
||||||
|
|
||||||
|
# GOTO
|
||||||
|
- [Features](#features)
|
||||||
|
- [Quick Start Guide](#quick-start-guide)
|
||||||
|
- [How to use?](#how-to-use)
|
||||||
|
- [Systemd Method](#telemt-via-systemd)
|
||||||
|
- [Configuration](#configuration)
|
||||||
|
- [Minimal Configuration](#minimal-configuration-for-first-start)
|
||||||
|
- [Advanced](#advanced)
|
||||||
|
- [Adtag](#adtag)
|
||||||
|
- [Listening and Announce IPs](#listening-and-announce-ips)
|
||||||
|
- [Upstream Manager](#upstream-manager)
|
||||||
|
- [IP](#bind-on-ip)
|
||||||
|
- [SOCKS](#socks45-as-upstream)
|
||||||
|
- [FAQ](#faq)
|
||||||
|
- [Recognizability for DPI + crawler](#recognizability-for-dpi-and-crawler)
|
||||||
|
- [Telegram Calls](#telegram-calls-via-mtproxy)
|
||||||
|
- [DPI](#how-does-dpi-see-mtproxy-tls)
|
||||||
|
- [Whitelist on Network Level](#whitelist-on-ip)
|
||||||
|
- [Build](#build)
|
||||||
|
- [Why Rust?](#why-rust)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Full support for all official MTProto proxy modes:
|
||||||
|
- Classic
|
||||||
|
- Secure - with `dd` prefix
|
||||||
|
- Fake TLS - with `ee` prefix + SNI fronting
|
||||||
|
- Replay attack protection
|
||||||
|
- Optional traffic masking: forward unrecognized connections to a real web server, e.g. GitHub 🤪
|
||||||
|
- Configurable keepalives + timeouts + IPv6 and "Fast Mode"
|
||||||
|
- Graceful shutdown on Ctrl+C
|
||||||
|
- Extensive logging via `trace` and `debug` with `RUST_LOG` method
|
||||||
|
|
||||||
|
## Quick Start Guide
|
||||||
|
**This software is designed for Debian-based OS: in addition to Debian, these are Ubuntu, Mint, Kali, MX and many other Linux**
|
||||||
|
1. Download release
|
||||||
|
```bash
|
||||||
|
wget https://github.com/telemt/telemt/releases/latest/download/telemt
|
||||||
|
```
|
||||||
|
2. Move to Bin Folder
|
||||||
|
```bash
|
||||||
|
mv telemt /bin
|
||||||
|
```
|
||||||
|
4. Make Executable
|
||||||
|
```bash
|
||||||
|
chmod +x /bin/telemt
|
||||||
|
```
|
||||||
|
5. Go to [How to use?](#how-to-use) section for for further steps
|
||||||
|
|
||||||
|
## How to use?
|
||||||
|
### Telemt via Systemd
|
||||||
|
**This instruction "assume" that you:**
|
||||||
|
- logged in as root or executed `su -` / `sudo su`
|
||||||
|
- you already have an assembled and executable `telemt` in /bin folder as a result of the [Quick Start Guide](#quick-start-guide) or [Build](#build)
|
||||||
|
|
||||||
|
**0. Check port and generate secrets**
|
||||||
|
|
||||||
|
The port you have selected for use should be MISSING from the list, when:
|
||||||
|
```bash
|
||||||
|
netstat -lnp
|
||||||
|
```
|
||||||
|
|
||||||
|
Generate 16 bytes/32 characters HEX with OpenSSL or another way:
|
||||||
|
```bash
|
||||||
|
openssl rand -hex 16
|
||||||
|
```
|
||||||
|
OR
|
||||||
|
```bash
|
||||||
|
xxd -l 16 -p /dev/urandom
|
||||||
|
```
|
||||||
|
OR
|
||||||
|
```bash
|
||||||
|
python3 -c 'import os; print(os.urandom(16).hex())'
|
||||||
|
```
|
||||||
|
|
||||||
|
**1. Place your config to /etc/telemt.toml**
|
||||||
|
|
||||||
|
Open nano
|
||||||
|
```bash
|
||||||
|
nano /etc/telemt.toml
|
||||||
|
```
|
||||||
|
paste your config from [Configuration](#configuration) section
|
||||||
|
|
||||||
|
then Ctrl+X -> Y -> Enter to save
|
||||||
|
|
||||||
|
**2. Create service on /etc/systemd/system/telemt.service**
|
||||||
|
|
||||||
|
Open nano
|
||||||
|
```bash
|
||||||
|
nano /etc/systemd/system/telemt.service
|
||||||
|
```
|
||||||
|
paste this Systemd Module
|
||||||
|
```bash
|
||||||
|
[Unit]
|
||||||
|
Description=Telemt
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/bin
|
||||||
|
ExecStart=/bin/telemt /etc/telemt.toml
|
||||||
|
Restart=on-failure
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
```
|
||||||
|
then Ctrl+X -> Y -> Enter to save
|
||||||
|
|
||||||
|
**3.** In Shell type `systemctl start telemt` - it must start with zero exit-code
|
||||||
|
|
||||||
|
**4.** In Shell type `systemctl status telemt` - there you can reach info about current MTProxy status
|
||||||
|
|
||||||
|
**5.** In Shell type `systemctl enable telemt` - then telemt will start with system startup, after the network is up
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
### Minimal Configuration for First Start
|
||||||
|
```toml
|
||||||
|
# === UI ===
|
||||||
|
# Users to show in the startup log (tg:// links)
|
||||||
|
show_link = ["hello"]
|
||||||
|
|
||||||
|
# === General Settings ===
|
||||||
|
[general]
|
||||||
|
prefer_ipv6 = false
|
||||||
|
fast_mode = true
|
||||||
|
use_middle_proxy = false
|
||||||
|
# ad_tag = "..."
|
||||||
|
|
||||||
|
[general.modes]
|
||||||
|
classic = false
|
||||||
|
secure = false
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
# === Server Binding ===
|
||||||
|
[server]
|
||||||
|
port = 443
|
||||||
|
listen_addr_ipv4 = "0.0.0.0"
|
||||||
|
listen_addr_ipv6 = "::"
|
||||||
|
# metrics_port = 9090
|
||||||
|
# metrics_whitelist = ["127.0.0.1", "::1"]
|
||||||
|
|
||||||
|
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "0.0.0.0"
|
||||||
|
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
|
||||||
|
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "::"
|
||||||
|
|
||||||
|
# === Timeouts (in seconds) ===
|
||||||
|
[timeouts]
|
||||||
|
client_handshake = 15
|
||||||
|
tg_connect = 10
|
||||||
|
client_keepalive = 60
|
||||||
|
client_ack = 300
|
||||||
|
|
||||||
|
# === Anti-Censorship & Masking ===
|
||||||
|
[censorship]
|
||||||
|
tls_domain = "petrovich.ru"
|
||||||
|
mask = true
|
||||||
|
mask_port = 443
|
||||||
|
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
||||||
|
fake_cert_len = 2048
|
||||||
|
|
||||||
|
# === Access Control & Users ===
|
||||||
|
# username "hello" is used for example
|
||||||
|
[access]
|
||||||
|
replay_check_len = 65536
|
||||||
|
ignore_time_skew = false
|
||||||
|
|
||||||
|
[access.users]
|
||||||
|
# format: "username" = "32_hex_chars_secret"
|
||||||
|
hello = "00000000000000000000000000000000"
|
||||||
|
|
||||||
|
# [access.user_max_tcp_conns]
|
||||||
|
# hello = 50
|
||||||
|
|
||||||
|
# [access.user_data_quota]
|
||||||
|
# hello = 1073741824 # 1 GB
|
||||||
|
|
||||||
|
# === Upstreams & Routing ===
|
||||||
|
# By default, direct connection is used, but you can add SOCKS proxy
|
||||||
|
|
||||||
|
# Direct - Default
|
||||||
|
[[upstreams]]
|
||||||
|
type = "direct"
|
||||||
|
enabled = true
|
||||||
|
weight = 10
|
||||||
|
|
||||||
|
# SOCKS5
|
||||||
|
# [[upstreams]]
|
||||||
|
# type = "socks5"
|
||||||
|
# address = "127.0.0.1:9050"
|
||||||
|
# enabled = false
|
||||||
|
# weight = 1
|
||||||
|
```
|
||||||
|
### Advanced
|
||||||
|
#### Adtag
|
||||||
|
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[General]`
|
||||||
|
```toml
|
||||||
|
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
|
||||||
|
```
|
||||||
|
#### Listening and Announce IPs
|
||||||
|
To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml:
|
||||||
|
```toml
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
|
||||||
|
announce_ip = "1.2.3.4" # IP in links; comment with # if not used
|
||||||
|
```
|
||||||
|
#### Upstream Manager
|
||||||
|
To specify upstream, add to section `[[upstreams]]` of config.toml:
|
||||||
|
##### Bind on IP
|
||||||
|
```toml
|
||||||
|
[[upstreams]]
|
||||||
|
type = "direct"
|
||||||
|
weight = 1
|
||||||
|
enabled = true
|
||||||
|
interface = "192.168.1.100" # Change to your outgoing IP
|
||||||
|
```
|
||||||
|
##### SOCKS4/5 as Upstream
|
||||||
|
- Without Auth:
|
||||||
|
```toml
|
||||||
|
[[upstreams]]
|
||||||
|
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||||
|
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||||
|
weight = 1 # Set Weight for Scenarios
|
||||||
|
enabled = true
|
||||||
|
```
|
||||||
|
|
||||||
|
- With Auth:
|
||||||
|
```toml
|
||||||
|
[[upstreams]]
|
||||||
|
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||||
|
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||||
|
username = "user" # Username for Auth on SOCKS-server
|
||||||
|
password = "pass" # Password for Auth on SOCKS-server
|
||||||
|
weight = 1 # Set Weight for Scenarios
|
||||||
|
enabled = true
|
||||||
|
```
|
||||||
|
|
||||||
|
## FAQ
|
||||||
|
### Recognizability for DPI and crawler
|
||||||
|
Since version 1.1.0.0, we have debugged masking perfectly: for all clients without "presenting" a key,
|
||||||
|
we transparently direct traffic to the target host!
|
||||||
|
|
||||||
|
- We consider this a breakthrough aspect, which has no stable analogues today
|
||||||
|
- Based on this: if `telemt` configured correctly, **TLS mode is completely identical to real-life handshake + communication** with a specified host
|
||||||
|
- Here is our evidence:
|
||||||
|
- 212.220.88.77 - "dummy" host, running `telemt`
|
||||||
|
- `petrovich.ru` - `tls` + `masking` host, in HEX: `706574726f766963682e7275`
|
||||||
|
- **No MITM + No Fake Certificates/Crypto** = pure transparent *TCP Splice* to "best" upstream: MTProxy or tls/mask-host:
|
||||||
|
- DPI see legitimate HTTPS to `tls_host`, including *valid chain-of-trust* and entropy
|
||||||
|
- Crawlers completely satisfied receiving responses from `mask_host`
|
||||||
|
#### Client WITH secret-key accesses the MTProxy resource:
|
||||||
|
|
||||||
|
<img width="360" height="439" alt="telemt" src="https://github.com/user-attachments/assets/39352afb-4a11-4ecc-9d91-9e8cfb20607d" />
|
||||||
|
|
||||||
|
#### Client WITHOUT secret-key gets transparent access to the specified resource:
|
||||||
|
- with trusted certificate
|
||||||
|
- with original handshake
|
||||||
|
- with full request-response way
|
||||||
|
- with low-latency overhead
|
||||||
|
```bash
|
||||||
|
root@debian:~/telemt# curl -v -I --resolve petrovich.ru:443:212.220.88.77 https://petrovich.ru/
|
||||||
|
* Added petrovich.ru:443:212.220.88.77 to DNS cache
|
||||||
|
* Hostname petrovich.ru was found in DNS cache
|
||||||
|
* Trying 212.220.88.77:443...
|
||||||
|
* Connected to petrovich.ru (212.220.88.77) port 443 (#0)
|
||||||
|
* ALPN: offers h2,http/1.1
|
||||||
|
* TLSv1.3 (OUT), TLS handshake, Client hello (1):
|
||||||
|
* CAfile: /etc/ssl/certs/ca-certificates.crt
|
||||||
|
* CApath: /etc/ssl/certs
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Server hello (2):
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Encrypted Extensions (8):
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Certificate (11):
|
||||||
|
* TLSv1.3 (IN), TLS handshake, CERT verify (15):
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Finished (20):
|
||||||
|
* TLSv1.3 (OUT), TLS change cipher, Change cipher spec (1):
|
||||||
|
* TLSv1.3 (OUT), TLS handshake, Finished (20):
|
||||||
|
* SSL connection using TLSv1.3 / TLS_AES_256_GCM_SHA384
|
||||||
|
* ALPN: server did not agree on a protocol. Uses default.
|
||||||
|
* Server certificate:
|
||||||
|
* subject: C=RU; ST=Saint Petersburg; L=Saint Petersburg; O=STD Petrovich; CN=*.petrovich.ru
|
||||||
|
* start date: Jan 28 11:21:01 2025 GMT
|
||||||
|
* expire date: Mar 1 11:21:00 2026 GMT
|
||||||
|
* subjectAltName: host "petrovich.ru" matched cert's "petrovich.ru"
|
||||||
|
* issuer: C=BE; O=GlobalSign nv-sa; CN=GlobalSign RSA OV SSL CA 2018
|
||||||
|
* SSL certificate verify ok.
|
||||||
|
* using HTTP/1.x
|
||||||
|
> HEAD / HTTP/1.1
|
||||||
|
> Host: petrovich.ru
|
||||||
|
> User-Agent: curl/7.88.1
|
||||||
|
> Accept: */*
|
||||||
|
>
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
|
||||||
|
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
|
||||||
|
* old SSL session ID is stale, removing
|
||||||
|
< HTTP/1.1 200 OK
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
< Server: Variti/0.9.3a
|
||||||
|
Server: Variti/0.9.3a
|
||||||
|
< Date: Thu, 01 Jan 2026 00:0000 GMT
|
||||||
|
Date: Thu, 01 Jan 2026 00:0000 GMT
|
||||||
|
< Access-Control-Allow-Origin: *
|
||||||
|
Access-Control-Allow-Origin: *
|
||||||
|
< Content-Type: text/html
|
||||||
|
Content-Type: text/html
|
||||||
|
< Cache-Control: no-store
|
||||||
|
Cache-Control: no-store
|
||||||
|
< Expires: Thu, 01 Jan 2026 00:0000 GMT
|
||||||
|
Expires: Thu, 01 Jan 2026 00:0000 GMT
|
||||||
|
< Pragma: no-cache
|
||||||
|
Pragma: no-cache
|
||||||
|
< Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
|
||||||
|
Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
|
||||||
|
< Content-Type: text/html
|
||||||
|
Content-Type: text/html
|
||||||
|
< Content-Length: 31253
|
||||||
|
Content-Length: 31253
|
||||||
|
< Connection: keep-alive
|
||||||
|
Connection: keep-alive
|
||||||
|
< Keep-Alive: timeout=60
|
||||||
|
Keep-Alive: timeout=60
|
||||||
|
|
||||||
|
<
|
||||||
|
* Connection #0 to host petrovich.ru left intact
|
||||||
|
|
||||||
|
```
|
||||||
|
- We challenged ourselves, we kept trying and we didn't only *beat the air*: now, we have something to show you
|
||||||
|
- Do not just take our word for it? - This is great and we respect that: you can build your own `telemt` or download a build and check it right now
|
||||||
|
### Telegram Calls via MTProxy
|
||||||
|
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
|
||||||
|
### How does DPI see MTProxy TLS?
|
||||||
|
- DPI sees MTProxy in Fake TLS (ee) mode as TLS 1.3
|
||||||
|
- the SNI you specify sends both the client and the server;
|
||||||
|
- ALPN is similar to HTTP 1.1/2;
|
||||||
|
- high entropy, which is normal for AES-encrypted traffic;
|
||||||
|
### Whitelist on IP
|
||||||
|
- MTProxy cannot work when there is:
|
||||||
|
- no IP connectivity to the target host: Russian Whitelist on Mobile Networks - "Белый список"
|
||||||
|
- OR all TCP traffic is blocked
|
||||||
|
- OR high entropy/encrypted traffic is blocked: content filters at universities and critical infrastructure
|
||||||
|
- OR all TLS traffic is blocked
|
||||||
|
- OR specified port is blocked: use 443 to make it "like real"
|
||||||
|
- OR provided SNI is blocked: use "officially approved"/innocuous name
|
||||||
|
- like most protocols on the Internet;
|
||||||
|
- these situations are observed:
|
||||||
|
- in China behind the Great Firewall
|
||||||
|
- in Russia on mobile networks, less in wired networks
|
||||||
|
- in Iran during "activity"
|
||||||
|
|
||||||
|
|
||||||
|
## Build
|
||||||
|
```bash
|
||||||
|
# Cloning repo
|
||||||
|
git clone https://github.com/telemt/telemt
|
||||||
|
# Changing Directory to telemt
|
||||||
|
cd telemt
|
||||||
|
# Starting Release Build
|
||||||
|
cargo build --release
|
||||||
|
# Move to /bin
|
||||||
|
mv ./target/release/telemt /bin
|
||||||
|
# Make executable
|
||||||
|
chmod +x /bin/telemt
|
||||||
|
# Lets go!
|
||||||
|
telemt config.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Why Rust?
|
||||||
|
- Long-running reliability and idempotent behavior
|
||||||
|
- Rust’s deterministic resource management - RAII
|
||||||
|
- No garbage collector
|
||||||
|
- Memory safety and reduced attack surface
|
||||||
|
- Tokio's asynchronous architecture
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
|
||||||
|
- ✅ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
- Public IP in links
|
||||||
|
- Config Reload-on-fly
|
||||||
|
- Bind to device or IP for outbound/inbound connections
|
||||||
|
- Adtag Support per SNI / Secret
|
||||||
|
- Fail-fast on start + Fail-soft on runtime (only WARN/ERROR)
|
||||||
|
- Zero-copy, minimal allocs on hotpath
|
||||||
|
- DC Healthchecks + global fallback
|
||||||
|
- No global mutable state
|
||||||
|
- Client isolation + Fair Bandwidth
|
||||||
|
- Backpressure-aware IO
|
||||||
|
- "Secret Policy" - SNI / Secret Routing :D
|
||||||
|
- Multi-upstream Balancer and Failover
|
||||||
|
- Strict FSM per handshake
|
||||||
|
- Session-based Antireplay with Sliding window, non-broking reconnects
|
||||||
|
- Web Control: statistic, state of health, latency, client experience...
|
||||||
|
|||||||
84
config.toml
84
config.toml
@@ -1,13 +1,79 @@
|
|||||||
port = 443
|
# === UI ===
|
||||||
|
# Users to show in the startup log (tg:// links)
|
||||||
|
show_link = ["hello"]
|
||||||
|
|
||||||
[users]
|
# === General Settings ===
|
||||||
user1 = "00000000000000000000000000000000"
|
[general]
|
||||||
|
prefer_ipv6 = false
|
||||||
|
fast_mode = true
|
||||||
|
use_middle_proxy = true
|
||||||
|
ad_tag = "00000000000000000000000000000000"
|
||||||
|
|
||||||
[modes]
|
# Log level: debug | verbose | normal | silent
|
||||||
classic = true
|
# Can be overridden with --silent or --log-level CLI flags
|
||||||
secure = true
|
# RUST_LOG env var takes absolute priority over all of these
|
||||||
|
log_level = "normal"
|
||||||
|
|
||||||
|
[general.modes]
|
||||||
|
classic = false
|
||||||
|
secure = false
|
||||||
tls = true
|
tls = true
|
||||||
|
|
||||||
tls_domain = "www.github.com"
|
# === Server Binding ===
|
||||||
fast_mode = true
|
[server]
|
||||||
prefer_ipv6 = false
|
port = 443
|
||||||
|
listen_addr_ipv4 = "0.0.0.0"
|
||||||
|
listen_addr_ipv6 = "::"
|
||||||
|
# metrics_port = 9090
|
||||||
|
# metrics_whitelist = ["127.0.0.1", "::1"]
|
||||||
|
|
||||||
|
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "0.0.0.0"
|
||||||
|
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
|
||||||
|
|
||||||
|
[[server.listeners]]
|
||||||
|
ip = "::"
|
||||||
|
|
||||||
|
# === Timeouts (in seconds) ===
|
||||||
|
[timeouts]
|
||||||
|
client_handshake = 15
|
||||||
|
tg_connect = 10
|
||||||
|
client_keepalive = 60
|
||||||
|
client_ack = 300
|
||||||
|
|
||||||
|
# === Anti-Censorship & Masking ===
|
||||||
|
[censorship]
|
||||||
|
tls_domain = "google.ru"
|
||||||
|
mask = true
|
||||||
|
mask_port = 443
|
||||||
|
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
|
||||||
|
fake_cert_len = 2048
|
||||||
|
|
||||||
|
# === Access Control & Users ===
|
||||||
|
[access]
|
||||||
|
replay_check_len = 65536
|
||||||
|
replay_window_secs = 1800
|
||||||
|
ignore_time_skew = false
|
||||||
|
|
||||||
|
[access.users]
|
||||||
|
# format: "username" = "32_hex_chars_secret"
|
||||||
|
hello = "00000000000000000000000000000000"
|
||||||
|
|
||||||
|
# [access.user_max_tcp_conns]
|
||||||
|
# hello = 50
|
||||||
|
|
||||||
|
# [access.user_data_quota]
|
||||||
|
# hello = 1073741824 # 1 GB
|
||||||
|
|
||||||
|
# === Upstreams & Routing ===
|
||||||
|
[[upstreams]]
|
||||||
|
type = "direct"
|
||||||
|
enabled = true
|
||||||
|
weight = 10
|
||||||
|
|
||||||
|
# [[upstreams]]
|
||||||
|
# type = "socks5"
|
||||||
|
# address = "127.0.0.1:9050"
|
||||||
|
# enabled = false
|
||||||
|
# weight = 1
|
||||||
300
src/cli.rs
Normal file
300
src/cli.rs
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
//! CLI commands: --init (fire-and-forget setup)
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Command;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
/// Options for the init command
|
||||||
|
pub struct InitOptions {
|
||||||
|
pub port: u16,
|
||||||
|
pub domain: String,
|
||||||
|
pub secret: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub config_dir: PathBuf,
|
||||||
|
pub no_start: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for InitOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
port: 443,
|
||||||
|
domain: "www.google.com".to_string(),
|
||||||
|
secret: None,
|
||||||
|
username: "user".to_string(),
|
||||||
|
config_dir: PathBuf::from("/etc/telemt"),
|
||||||
|
no_start: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse --init subcommand options from CLI args.
|
||||||
|
///
|
||||||
|
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
|
||||||
|
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
|
||||||
|
if !args.iter().any(|a| a == "--init") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut opts = InitOptions::default();
|
||||||
|
let mut i = 0;
|
||||||
|
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--port" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.port = args[i].parse().unwrap_or(443);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--domain" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.domain = args[i].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--secret" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.secret = Some(args[i].clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--user" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.username = args[i].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--config-dir" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() {
|
||||||
|
opts.config_dir = PathBuf::from(&args[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--no-start" => {
|
||||||
|
opts.no_start = true;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the fire-and-forget setup.
|
||||||
|
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
eprintln!("[telemt] Fire-and-forget setup");
|
||||||
|
eprintln!();
|
||||||
|
|
||||||
|
// 1. Generate or validate secret
|
||||||
|
let secret = match opts.secret {
|
||||||
|
Some(s) => {
|
||||||
|
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||||
|
eprintln!("[error] Secret must be exactly 32 hex characters");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
s
|
||||||
|
}
|
||||||
|
None => generate_secret(),
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!("[+] Secret: {}", secret);
|
||||||
|
eprintln!("[+] User: {}", opts.username);
|
||||||
|
eprintln!("[+] Port: {}", opts.port);
|
||||||
|
eprintln!("[+] Domain: {}", opts.domain);
|
||||||
|
|
||||||
|
// 2. Create config directory
|
||||||
|
fs::create_dir_all(&opts.config_dir)?;
|
||||||
|
let config_path = opts.config_dir.join("config.toml");
|
||||||
|
|
||||||
|
// 3. Write config
|
||||||
|
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
fs::write(&config_path, &config_content)?;
|
||||||
|
eprintln!("[+] Config written to {}", config_path.display());
|
||||||
|
|
||||||
|
// 4. Write systemd unit
|
||||||
|
let exe_path = std::env::current_exe()
|
||||||
|
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
|
||||||
|
|
||||||
|
let unit_path = Path::new("/etc/systemd/system/telemt.service");
|
||||||
|
let unit_content = generate_systemd_unit(&exe_path, &config_path);
|
||||||
|
|
||||||
|
match fs::write(unit_path, &unit_content) {
|
||||||
|
Ok(()) => {
|
||||||
|
eprintln!("[+] Systemd unit written to {}", unit_path.display());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
|
||||||
|
eprintln!("[!] Manual unit file content:");
|
||||||
|
eprintln!("{}", unit_content);
|
||||||
|
|
||||||
|
// Still print links and config
|
||||||
|
print_links(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Reload systemd
|
||||||
|
run_cmd("systemctl", &["daemon-reload"]);
|
||||||
|
|
||||||
|
// 6. Enable service
|
||||||
|
run_cmd("systemctl", &["enable", "telemt.service"]);
|
||||||
|
eprintln!("[+] Service enabled");
|
||||||
|
|
||||||
|
// 7. Start service (unless --no-start)
|
||||||
|
if !opts.no_start {
|
||||||
|
run_cmd("systemctl", &["start", "telemt.service"]);
|
||||||
|
eprintln!("[+] Service started");
|
||||||
|
|
||||||
|
// Brief delay then check status
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
let status = Command::new("systemctl")
|
||||||
|
.args(["is-active", "telemt.service"])
|
||||||
|
.output();
|
||||||
|
|
||||||
|
match status {
|
||||||
|
Ok(out) if out.status.success() => {
|
||||||
|
eprintln!("[+] Service is running");
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("[!] Service may not have started correctly");
|
||||||
|
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eprintln!("[+] Service not started (--no-start)");
|
||||||
|
eprintln!("[+] Start manually: systemctl start telemt.service");
|
||||||
|
}
|
||||||
|
|
||||||
|
eprintln!();
|
||||||
|
|
||||||
|
// 8. Print links
|
||||||
|
print_links(&opts.username, &secret, opts.port, &opts.domain);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_secret() -> String {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
|
||||||
|
hex::encode(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render the auto-generated `config.toml` contents for a single-user setup.
///
/// The `{username}`, `{secret}`, `{port}` and `{domain}` placeholders are
/// filled directly from the function parameters via implicit format-argument
/// capture, so no trailing named-argument list is needed.
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
    format!(
        r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate

show_link = ["{username}"]

[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"

[general.modes]
classic = false
secure = false
tls = true

[server]
port = {port}
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"

[[server.listeners]]
ip = "0.0.0.0"

[[server.listeners]]
ip = "::"

[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300

[censorship]
tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048

[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false

[access.users]
{username} = "{secret}"

[[upstreams]]
type = "direct"
enabled = true
weight = 10
"#,
    )
}
|
||||||
|
|
||||||
|
/// Render the systemd unit file pointing at the given binary and config path.
///
/// The unit restarts on failure and applies basic sandboxing; the config
/// directory `/etc/telemt` stays writable for regeneration.
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
    // Bind display adapters up front so the template can capture them by name.
    let exe = exe_path.display();
    let config = config_path.display();

    format!(
        r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/nicepkg/telemt
After=network-online.target
Wants=network-online.target

[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true

[Install]
WantedBy=multi-user.target
"#,
    )
}
|
||||||
|
|
||||||
|
/// Run an external command, logging (but never propagating) failures.
///
/// Both "could not spawn" and "exited non-zero" are reported to stderr
/// and then ignored — callers use this for best-effort steps like
/// `systemctl` invocations where setup should continue regardless.
fn run_cmd(cmd: &str, args: &[&str]) {
    let outcome = Command::new(cmd).args(args).output();
    match outcome {
        Err(e) => {
            eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
        }
        Ok(output) if !output.status.success() => {
            let stderr = String::from_utf8_lossy(&output.stderr);
            eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
        }
        Ok(_) => {}
    }
}
|
||||||
|
|
||||||
|
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
|
||||||
|
let domain_hex = hex::encode(domain);
|
||||||
|
|
||||||
|
println!("=== Proxy Links ===");
|
||||||
|
println!("[{}]", username);
|
||||||
|
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
|
||||||
|
port, secret, domain_hex);
|
||||||
|
println!();
|
||||||
|
println!("Replace YOUR_SERVER_IP with your server's public IP.");
|
||||||
|
println!("The proxy will auto-detect and display the correct link on startup.");
|
||||||
|
println!("Check: journalctl -u telemt.service | head -30");
|
||||||
|
println!("===================");
|
||||||
|
}
|
||||||
@@ -1,12 +1,89 @@
|
|||||||
//! Configuration
|
//! Configuration
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::IpAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
|
// ============= Helper Defaults =============
|
||||||
|
|
||||||
|
fn default_true() -> bool { true }
|
||||||
|
fn default_port() -> u16 { 443 }
|
||||||
|
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||||
|
fn default_mask_port() -> u16 { 443 }
|
||||||
|
fn default_replay_check_len() -> usize { 65536 }
|
||||||
|
fn default_replay_window_secs() -> u64 { 1800 }
|
||||||
|
fn default_handshake_timeout() -> u64 { 15 }
|
||||||
|
fn default_connect_timeout() -> u64 { 10 }
|
||||||
|
fn default_keepalive() -> u64 { 60 }
|
||||||
|
fn default_ack_timeout() -> u64 { 300 }
|
||||||
|
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
||||||
|
fn default_fake_cert_len() -> usize { 2048 }
|
||||||
|
fn default_weight() -> u16 { 1 }
|
||||||
|
fn default_metrics_whitelist() -> Vec<IpAddr> {
|
||||||
|
vec![
|
||||||
|
"127.0.0.1".parse().unwrap(),
|
||||||
|
"::1".parse().unwrap(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Log Level =============
|
||||||
|
|
||||||
|
/// Logging verbosity level
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum LogLevel {
|
||||||
|
/// All messages including trace (trace + debug + info + warn + error)
|
||||||
|
Debug,
|
||||||
|
/// Detailed operational logs (debug + info + warn + error)
|
||||||
|
Verbose,
|
||||||
|
/// Standard operational logs (info + warn + error)
|
||||||
|
#[default]
|
||||||
|
Normal,
|
||||||
|
/// Minimal output: only warnings and errors (warn + error).
|
||||||
|
/// Startup messages (config, DC connectivity, proxy links) are always shown
|
||||||
|
/// via info! before the filter is applied.
|
||||||
|
Silent,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogLevel {
|
||||||
|
/// Convert to tracing EnvFilter directive string
|
||||||
|
pub fn to_filter_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
LogLevel::Debug => "trace",
|
||||||
|
LogLevel::Verbose => "debug",
|
||||||
|
LogLevel::Normal => "info",
|
||||||
|
LogLevel::Silent => "warn",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse from a loose string (CLI argument)
|
||||||
|
pub fn from_str_loose(s: &str) -> Self {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"debug" | "trace" => LogLevel::Debug,
|
||||||
|
"verbose" => LogLevel::Verbose,
|
||||||
|
"normal" | "info" => LogLevel::Normal,
|
||||||
|
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
|
||||||
|
_ => LogLevel::Normal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for LogLevel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
LogLevel::Debug => write!(f, "debug"),
|
||||||
|
LogLevel::Verbose => write!(f, "verbose"),
|
||||||
|
LogLevel::Normal => write!(f, "normal"),
|
||||||
|
LogLevel::Silent => write!(f, "silent"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Sub-Configs =============
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ProxyModes {
|
pub struct ProxyModes {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -17,8 +94,6 @@ pub struct ProxyModes {
|
|||||||
pub tls: bool,
|
pub tls: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_true() -> bool { true }
|
|
||||||
|
|
||||||
impl Default for ProxyModes {
|
impl Default for ProxyModes {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self { classic: true, secure: true, tls: true }
|
Self { classic: true, secure: true, tls: true }
|
||||||
@@ -26,31 +101,10 @@ impl Default for ProxyModes {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ProxyConfig {
|
pub struct GeneralConfig {
|
||||||
#[serde(default = "default_port")]
|
|
||||||
pub port: u16,
|
|
||||||
|
|
||||||
#[serde(default)]
|
|
||||||
pub users: HashMap<String, String>,
|
|
||||||
|
|
||||||
#[serde(default)]
|
|
||||||
pub ad_tag: Option<String>,
|
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub modes: ProxyModes,
|
pub modes: ProxyModes,
|
||||||
|
|
||||||
#[serde(default = "default_tls_domain")]
|
|
||||||
pub tls_domain: String,
|
|
||||||
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub mask: bool,
|
|
||||||
|
|
||||||
#[serde(default)]
|
|
||||||
pub mask_host: Option<String>,
|
|
||||||
|
|
||||||
#[serde(default = "default_mask_port")]
|
|
||||||
pub mask_port: u16,
|
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub prefer_ipv6: bool,
|
pub prefer_ipv6: bool,
|
||||||
|
|
||||||
@@ -61,31 +115,29 @@ pub struct ProxyConfig {
|
|||||||
pub use_middle_proxy: bool,
|
pub use_middle_proxy: bool,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub user_max_tcp_conns: HashMap<String, usize>,
|
pub ad_tag: Option<String>,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub user_expirations: HashMap<String, DateTime<Utc>>,
|
pub log_level: LogLevel,
|
||||||
|
}
|
||||||
|
|
||||||
#[serde(default)]
|
impl Default for GeneralConfig {
|
||||||
pub user_data_quota: HashMap<String, u64>,
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
modes: ProxyModes::default(),
|
||||||
|
prefer_ipv6: false,
|
||||||
|
fast_mode: true,
|
||||||
|
use_middle_proxy: false,
|
||||||
|
ad_tag: None,
|
||||||
|
log_level: LogLevel::Normal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[serde(default = "default_replay_check_len")]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub replay_check_len: usize,
|
pub struct ServerConfig {
|
||||||
|
#[serde(default = "default_port")]
|
||||||
#[serde(default)]
|
pub port: u16,
|
||||||
pub ignore_time_skew: bool,
|
|
||||||
|
|
||||||
#[serde(default = "default_handshake_timeout")]
|
|
||||||
pub client_handshake_timeout: u64,
|
|
||||||
|
|
||||||
#[serde(default = "default_connect_timeout")]
|
|
||||||
pub tg_connect_timeout: u64,
|
|
||||||
|
|
||||||
#[serde(default = "default_keepalive")]
|
|
||||||
pub client_keepalive: u64,
|
|
||||||
|
|
||||||
#[serde(default = "default_ack_timeout")]
|
|
||||||
pub client_ack_timeout: u64,
|
|
||||||
|
|
||||||
#[serde(default = "default_listen_addr")]
|
#[serde(default = "default_listen_addr")]
|
||||||
pub listen_addr_ipv4: String,
|
pub listen_addr_ipv4: String,
|
||||||
@@ -102,64 +154,205 @@ pub struct ProxyConfig {
|
|||||||
#[serde(default = "default_metrics_whitelist")]
|
#[serde(default = "default_metrics_whitelist")]
|
||||||
pub metrics_whitelist: Vec<IpAddr>,
|
pub metrics_whitelist: Vec<IpAddr>,
|
||||||
|
|
||||||
#[serde(default = "default_fake_cert_len")]
|
#[serde(default)]
|
||||||
pub fake_cert_len: usize,
|
pub listeners: Vec<ListenerConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_port() -> u16 { 443 }
|
impl Default for ServerConfig {
|
||||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
|
||||||
fn default_mask_port() -> u16 { 443 }
|
|
||||||
fn default_replay_check_len() -> usize { 65536 }
|
|
||||||
fn default_handshake_timeout() -> u64 { 10 }
|
|
||||||
fn default_connect_timeout() -> u64 { 10 }
|
|
||||||
fn default_keepalive() -> u64 { 600 }
|
|
||||||
fn default_ack_timeout() -> u64 { 300 }
|
|
||||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
|
||||||
fn default_fake_cert_len() -> usize { 2048 }
|
|
||||||
|
|
||||||
fn default_metrics_whitelist() -> Vec<IpAddr> {
|
|
||||||
vec![
|
|
||||||
"127.0.0.1".parse().unwrap(),
|
|
||||||
"::1".parse().unwrap(),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ProxyConfig {
|
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
let mut users = HashMap::new();
|
|
||||||
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
port: default_port(),
|
port: default_port(),
|
||||||
users,
|
|
||||||
ad_tag: None,
|
|
||||||
modes: ProxyModes::default(),
|
|
||||||
tls_domain: default_tls_domain(),
|
|
||||||
mask: true,
|
|
||||||
mask_host: None,
|
|
||||||
mask_port: default_mask_port(),
|
|
||||||
prefer_ipv6: false,
|
|
||||||
fast_mode: true,
|
|
||||||
use_middle_proxy: false,
|
|
||||||
user_max_tcp_conns: HashMap::new(),
|
|
||||||
user_expirations: HashMap::new(),
|
|
||||||
user_data_quota: HashMap::new(),
|
|
||||||
replay_check_len: default_replay_check_len(),
|
|
||||||
ignore_time_skew: false,
|
|
||||||
client_handshake_timeout: default_handshake_timeout(),
|
|
||||||
tg_connect_timeout: default_connect_timeout(),
|
|
||||||
client_keepalive: default_keepalive(),
|
|
||||||
client_ack_timeout: default_ack_timeout(),
|
|
||||||
listen_addr_ipv4: default_listen_addr(),
|
listen_addr_ipv4: default_listen_addr(),
|
||||||
listen_addr_ipv6: Some("::".to_string()),
|
listen_addr_ipv6: Some("::".to_string()),
|
||||||
listen_unix_sock: None,
|
listen_unix_sock: None,
|
||||||
metrics_port: None,
|
metrics_port: None,
|
||||||
metrics_whitelist: default_metrics_whitelist(),
|
metrics_whitelist: default_metrics_whitelist(),
|
||||||
|
listeners: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TimeoutsConfig {
|
||||||
|
#[serde(default = "default_handshake_timeout")]
|
||||||
|
pub client_handshake: u64,
|
||||||
|
|
||||||
|
#[serde(default = "default_connect_timeout")]
|
||||||
|
pub tg_connect: u64,
|
||||||
|
|
||||||
|
#[serde(default = "default_keepalive")]
|
||||||
|
pub client_keepalive: u64,
|
||||||
|
|
||||||
|
#[serde(default = "default_ack_timeout")]
|
||||||
|
pub client_ack: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TimeoutsConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
client_handshake: default_handshake_timeout(),
|
||||||
|
tg_connect: default_connect_timeout(),
|
||||||
|
client_keepalive: default_keepalive(),
|
||||||
|
client_ack: default_ack_timeout(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AntiCensorshipConfig {
|
||||||
|
#[serde(default = "default_tls_domain")]
|
||||||
|
pub tls_domain: String,
|
||||||
|
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub mask: bool,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub mask_host: Option<String>,
|
||||||
|
|
||||||
|
#[serde(default = "default_mask_port")]
|
||||||
|
pub mask_port: u16,
|
||||||
|
|
||||||
|
#[serde(default = "default_fake_cert_len")]
|
||||||
|
pub fake_cert_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AntiCensorshipConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
tls_domain: default_tls_domain(),
|
||||||
|
mask: true,
|
||||||
|
mask_host: None,
|
||||||
|
mask_port: default_mask_port(),
|
||||||
fake_cert_len: default_fake_cert_len(),
|
fake_cert_len: default_fake_cert_len(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AccessConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub users: HashMap<String, String>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub user_max_tcp_conns: HashMap<String, usize>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub user_expirations: HashMap<String, DateTime<Utc>>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub user_data_quota: HashMap<String, u64>,
|
||||||
|
|
||||||
|
#[serde(default = "default_replay_check_len")]
|
||||||
|
pub replay_check_len: usize,
|
||||||
|
|
||||||
|
#[serde(default = "default_replay_window_secs")]
|
||||||
|
pub replay_window_secs: u64,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub ignore_time_skew: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AccessConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
let mut users = HashMap::new();
|
||||||
|
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
|
||||||
|
Self {
|
||||||
|
users,
|
||||||
|
user_max_tcp_conns: HashMap::new(),
|
||||||
|
user_expirations: HashMap::new(),
|
||||||
|
user_data_quota: HashMap::new(),
|
||||||
|
replay_check_len: default_replay_check_len(),
|
||||||
|
replay_window_secs: default_replay_window_secs(),
|
||||||
|
ignore_time_skew: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Aux Structures =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "lowercase")]
|
||||||
|
pub enum UpstreamType {
|
||||||
|
Direct {
|
||||||
|
#[serde(default)]
|
||||||
|
interface: Option<String>,
|
||||||
|
},
|
||||||
|
Socks4 {
|
||||||
|
address: String,
|
||||||
|
#[serde(default)]
|
||||||
|
interface: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
user_id: Option<String>,
|
||||||
|
},
|
||||||
|
Socks5 {
|
||||||
|
address: String,
|
||||||
|
#[serde(default)]
|
||||||
|
interface: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
username: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
password: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct UpstreamConfig {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub upstream_type: UpstreamType,
|
||||||
|
#[serde(default = "default_weight")]
|
||||||
|
pub weight: u16,
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ListenerConfig {
|
||||||
|
pub ip: IpAddr,
|
||||||
|
#[serde(default)]
|
||||||
|
pub announce_ip: Option<IpAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Main Config =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct ProxyConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub general: GeneralConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub server: ServerConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub timeouts: TimeoutsConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub censorship: AntiCensorshipConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub access: AccessConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub upstreams: Vec<UpstreamConfig>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub show_link: Vec<String>,
|
||||||
|
|
||||||
|
/// DC address overrides for non-standard DCs (CDN, media, test, etc.)
|
||||||
|
/// Keys are DC indices as strings, values are "ip:port" addresses.
|
||||||
|
/// Matches the C implementation's `proxy_for <dc_id> <ip>:<port>` config directive.
|
||||||
|
/// Example in config.toml:
|
||||||
|
/// [dc_overrides]
|
||||||
|
/// "203" = "149.154.175.100:443"
|
||||||
|
#[serde(default)]
|
||||||
|
pub dc_overrides: HashMap<String, String>,
|
||||||
|
|
||||||
|
/// Default DC index (1-5) for unmapped non-standard DCs.
|
||||||
|
/// Matches the C implementation's `default <dc_id>` config directive.
|
||||||
|
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_dc: Option<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
impl ProxyConfig {
|
impl ProxyConfig {
|
||||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||||
let content = std::fs::read_to_string(path)
|
let content = std::fs::read_to_string(path)
|
||||||
@@ -169,7 +362,7 @@ impl ProxyConfig {
|
|||||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
|
|
||||||
// Validate secrets
|
// Validate secrets
|
||||||
for (user, secret) in &config.users {
|
for (user, secret) in &config.access.users {
|
||||||
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
|
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
|
||||||
return Err(ProxyError::InvalidSecret {
|
return Err(ProxyError::InvalidSecret {
|
||||||
user: user.clone(),
|
user: user.clone(),
|
||||||
@@ -178,50 +371,65 @@ impl ProxyConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default mask_host
|
// Validate tls_domain
|
||||||
if config.mask_host.is_none() {
|
if config.censorship.tls_domain.is_empty() {
|
||||||
config.mask_host = Some(config.tls_domain.clone());
|
return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default mask_host to tls_domain if not set
|
||||||
|
if config.censorship.mask_host.is_none() {
|
||||||
|
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Random fake_cert_len
|
// Random fake_cert_len
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
|
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
|
||||||
|
|
||||||
|
// Migration: Populate listeners if empty
|
||||||
|
if config.server.listeners.is_empty() {
|
||||||
|
if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
|
||||||
|
config.server.listeners.push(ListenerConfig {
|
||||||
|
ip: ipv4,
|
||||||
|
announce_ip: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
|
||||||
|
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
|
||||||
|
config.server.listeners.push(ListenerConfig {
|
||||||
|
ip: ipv6,
|
||||||
|
announce_ip: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migration: Populate upstreams if empty (Default Direct)
|
||||||
|
if config.upstreams.is_empty() {
|
||||||
|
config.upstreams.push(UpstreamConfig {
|
||||||
|
upstream_type: UpstreamType::Direct { interface: None },
|
||||||
|
weight: 1,
|
||||||
|
enabled: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn validate(&self) -> Result<()> {
|
pub fn validate(&self) -> Result<()> {
|
||||||
if self.users.is_empty() {
|
if self.access.users.is_empty() {
|
||||||
return Err(ProxyError::Config("No users configured".to_string()));
|
return Err(ProxyError::Config("No users configured".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if !self.modes.classic && !self.modes.secure && !self.modes.tls {
|
if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
|
||||||
return Err(ProxyError::Config("No modes enabled".to_string()));
|
return Err(ProxyError::Config("No modes enabled".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
|
||||||
|
return Err(ProxyError::Config(
|
||||||
|
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_default_config() {
|
|
||||||
let config = ProxyConfig::default();
|
|
||||||
assert_eq!(config.port, 443);
|
|
||||||
assert!(config.modes.tls);
|
|
||||||
assert_eq!(config.client_keepalive, 600);
|
|
||||||
assert_eq!(config.client_ack_timeout, 300);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_config_validate() {
|
|
||||||
let mut config = ProxyConfig::default();
|
|
||||||
assert!(config.validate().is_ok());
|
|
||||||
|
|
||||||
config.users.clear();
|
|
||||||
assert!(config.validate().is_err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +1,39 @@
|
|||||||
//! AES
|
//! AES encryption implementations
|
||||||
|
//!
|
||||||
|
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
|
||||||
|
//!
|
||||||
|
//! ## Zeroize policy
|
||||||
|
//!
|
||||||
|
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
|
||||||
|
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
|
||||||
|
//! The expanded key schedule lives inside that type and cannot be
|
||||||
|
//! zeroized from outside. Callers that hold raw key material (e.g.
|
||||||
|
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
|
||||||
|
//! zeroizing their own copies.
|
||||||
|
|
||||||
use aes::Aes256;
|
use aes::Aes256;
|
||||||
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
|
||||||
use cbc::{Encryptor as CbcEncryptor, Decryptor as CbcDecryptor};
|
use zeroize::Zeroize;
|
||||||
use cbc::cipher::{BlockEncryptMut, BlockDecryptMut, block_padding::NoPadding};
|
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
type Aes256Ctr = Ctr128BE<Aes256>;
|
type Aes256Ctr = Ctr128BE<Aes256>;
|
||||||
type Aes256CbcEnc = CbcEncryptor<Aes256>;
|
|
||||||
type Aes256CbcDec = CbcDecryptor<Aes256>;
|
// ============= AES-256-CTR =============
|
||||||
|
|
||||||
/// AES-256-CTR encryptor/decryptor
|
/// AES-256-CTR encryptor/decryptor
|
||||||
|
///
|
||||||
|
/// CTR mode is symmetric — encryption and decryption are the same operation.
|
||||||
|
///
|
||||||
|
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
|
||||||
|
/// + counter) is opaque and cannot be zeroized. If you need to protect key
|
||||||
|
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
|
||||||
|
/// before dropping them.
|
||||||
pub struct AesCtr {
|
pub struct AesCtr {
|
||||||
cipher: Aes256Ctr,
|
cipher: Aes256Ctr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AesCtr {
|
impl AesCtr {
|
||||||
|
/// Create new AES-CTR cipher with key and IV
|
||||||
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
pub fn new(key: &[u8; 32], iv: u128) -> Self {
|
||||||
let iv_bytes = iv.to_be_bytes();
|
let iv_bytes = iv.to_be_bytes();
|
||||||
Self {
|
Self {
|
||||||
@@ -23,6 +41,7 @@ impl AesCtr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create from key and IV slices
|
||||||
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||||
if key.len() != 32 {
|
if key.len() != 32 {
|
||||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||||
@@ -54,17 +73,37 @@ impl AesCtr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// AES-256-CBC Ciphermagic
|
// ============= AES-256-CBC =============
|
||||||
|
|
||||||
|
/// AES-256-CBC cipher with proper chaining
|
||||||
|
///
|
||||||
|
/// Unlike CTR mode, CBC is NOT symmetric — encryption and decryption
|
||||||
|
/// are different operations. This implementation handles CBC chaining
|
||||||
|
/// correctly across multiple blocks.
|
||||||
|
///
|
||||||
|
/// Key and IV are zeroized on drop.
|
||||||
pub struct AesCbc {
|
pub struct AesCbc {
|
||||||
key: [u8; 32],
|
key: [u8; 32],
|
||||||
iv: [u8; 16],
|
iv: [u8; 16],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for AesCbc {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.key.zeroize();
|
||||||
|
self.iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AesCbc {
|
impl AesCbc {
|
||||||
|
/// AES block size
|
||||||
|
const BLOCK_SIZE: usize = 16;
|
||||||
|
|
||||||
|
/// Create new AES-CBC cipher with key and IV
|
||||||
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
|
||||||
Self { key, iv }
|
Self { key, iv }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create from slices
|
||||||
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
|
||||||
if key.len() != 32 {
|
if key.len() != 32 {
|
||||||
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
|
||||||
@@ -79,9 +118,36 @@ impl AesCbc {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt data using CBC mode
|
/// Encrypt a single block using raw AES (no chaining)
|
||||||
|
fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||||
|
use aes::cipher::BlockEncrypt;
|
||||||
|
let mut output = *block;
|
||||||
|
key_schedule.encrypt_block((&mut output).into());
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt a single block using raw AES (no chaining)
|
||||||
|
fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
|
||||||
|
use aes::cipher::BlockDecrypt;
|
||||||
|
let mut output = *block;
|
||||||
|
key_schedule.decrypt_block((&mut output).into());
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// XOR two 16-byte blocks
|
||||||
|
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
|
||||||
|
let mut result = [0u8; 16];
|
||||||
|
for i in 0..16 {
|
||||||
|
result[i] = a[i] ^ b[i];
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypt data using CBC mode with proper chaining
|
||||||
|
///
|
||||||
|
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
|
||||||
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -91,20 +157,28 @@ impl AesCbc {
|
|||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buffer = data.to_vec();
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
let mut result = Vec::with_capacity(data.len());
|
||||||
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
for chunk in buffer.chunks_mut(16) {
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
encryptor.encrypt_block_mut(chunk.into());
|
let plaintext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
|
||||||
|
let ciphertext = self.encrypt_block(&xored, &key_schedule);
|
||||||
|
prev_ciphertext = ciphertext;
|
||||||
|
result.extend_from_slice(&ciphertext);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(buffer)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrypt data using CBC mode
|
/// Decrypt data using CBC mode with proper chaining
|
||||||
|
///
|
||||||
|
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
|
||||||
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -114,20 +188,26 @@ impl AesCbc {
|
|||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buffer = data.to_vec();
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
let mut result = Vec::with_capacity(data.len());
|
||||||
|
let mut prev_ciphertext = self.iv;
|
||||||
|
|
||||||
for chunk in buffer.chunks_mut(16) {
|
for chunk in data.chunks(Self::BLOCK_SIZE) {
|
||||||
decryptor.decrypt_block_mut(chunk.into());
|
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
|
||||||
|
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
|
||||||
|
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
|
||||||
|
prev_ciphertext = ciphertext;
|
||||||
|
result.extend_from_slice(&plaintext);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(buffer)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt data in-place
|
/// Encrypt data in-place
|
||||||
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -137,10 +217,22 @@ impl AesCbc {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut encryptor = Aes256CbcEnc::new((&self.key).into(), (&self.iv).into());
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
for chunk in data.chunks_mut(16) {
|
let mut prev_ciphertext = self.iv;
|
||||||
encryptor.encrypt_block_mut(chunk.into());
|
|
||||||
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
|
block[j] ^= prev_ciphertext[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
|
*block_array = self.encrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
|
prev_ciphertext = *block_array;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -148,7 +240,7 @@ impl AesCbc {
|
|||||||
|
|
||||||
/// Decrypt data in-place
|
/// Decrypt data in-place
|
||||||
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
|
||||||
if data.len() % 16 != 0 {
|
if data.len() % Self::BLOCK_SIZE != 0 {
|
||||||
return Err(ProxyError::Crypto(
|
return Err(ProxyError::Crypto(
|
||||||
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
|
||||||
));
|
));
|
||||||
@@ -158,16 +250,32 @@ impl AesCbc {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut decryptor = Aes256CbcDec::new((&self.key).into(), (&self.iv).into());
|
use aes::cipher::KeyInit;
|
||||||
|
let key_schedule = aes::Aes256::new((&self.key).into());
|
||||||
|
|
||||||
for chunk in data.chunks_mut(16) {
|
let mut prev_ciphertext = self.iv;
|
||||||
decryptor.decrypt_block_mut(chunk.into());
|
|
||||||
|
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
|
||||||
|
let block = &mut data[i..i + Self::BLOCK_SIZE];
|
||||||
|
|
||||||
|
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
|
||||||
|
|
||||||
|
let block_array: &mut [u8; 16] = block.try_into().unwrap();
|
||||||
|
*block_array = self.decrypt_block(block_array, &key_schedule);
|
||||||
|
|
||||||
|
for j in 0..Self::BLOCK_SIZE {
|
||||||
|
block[j] ^= prev_ciphertext[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
prev_ciphertext = current_ciphertext;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= Encryption Traits =============
|
||||||
|
|
||||||
/// Trait for unified encryption interface
|
/// Trait for unified encryption interface
|
||||||
pub trait Encryptor: Send + Sync {
|
pub trait Encryptor: Send + Sync {
|
||||||
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
fn encrypt(&mut self, data: &[u8]) -> Vec<u8>;
|
||||||
@@ -209,6 +317,8 @@ impl Decryptor for PassthroughEncryptor {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
// ============= AES-CTR Tests =============
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_ctr_roundtrip() {
|
fn test_aes_ctr_roundtrip() {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
@@ -225,12 +335,32 @@ mod tests {
|
|||||||
assert_eq!(original.as_slice(), decrypted.as_slice());
|
assert_eq!(original.as_slice(), decrypted.as_slice());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_ctr_in_place() {
|
||||||
|
let key = [0x42u8; 32];
|
||||||
|
let iv = 999u128;
|
||||||
|
|
||||||
|
let original = b"Test data for in-place encryption";
|
||||||
|
let mut data = original.to_vec();
|
||||||
|
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
|
assert_ne!(&data[..], original);
|
||||||
|
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.apply(&mut data);
|
||||||
|
|
||||||
|
assert_eq!(&data[..], original);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= AES-CBC Tests =============
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_roundtrip() {
|
fn test_aes_cbc_roundtrip() {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
|
|
||||||
// Must be aligned to 16 bytes
|
|
||||||
let original = [0u8; 32];
|
let original = [0u8; 32];
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
@@ -245,44 +375,47 @@ mod tests {
|
|||||||
let key = [0x42u8; 32];
|
let key = [0x42u8; 32];
|
||||||
let iv = [0x00u8; 16];
|
let iv = [0x00u8; 16];
|
||||||
|
|
||||||
let plaintext = [0xAA_u8; 32];
|
let plaintext = [0xAAu8; 32];
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// CBC Corrections
|
|
||||||
let block1 = &ciphertext[0..16];
|
let block1 = &ciphertext[0..16];
|
||||||
let block2 = &ciphertext[16..32];
|
let block2 = &ciphertext[16..32];
|
||||||
|
|
||||||
assert_ne!(block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext");
|
assert_ne!(
|
||||||
|
block1, block2,
|
||||||
|
"CBC chaining broken: identical plaintext blocks produced identical ciphertext"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_known_vector() {
|
fn test_aes_cbc_known_vector() {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = [0u8; 16];
|
let iv = [0u8; 16];
|
||||||
|
let plaintext = [0u8; 16];
|
||||||
// 3 Datablocks
|
|
||||||
let plaintext = [
|
|
||||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
|
||||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
|
||||||
// Block 2
|
|
||||||
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
|
|
||||||
0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF,
|
|
||||||
// Block 3 - different
|
|
||||||
0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88,
|
|
||||||
0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x00,
|
|
||||||
];
|
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
// Decrypt + Verify
|
|
||||||
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||||
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
|
||||||
|
|
||||||
// Verify Ciphertexts Block 1 != Block 2
|
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
|
||||||
assert_ne!(&ciphertext[0..16], &ciphertext[16..32]);
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_multi_block() {
|
||||||
|
let key = [0x12u8; 32];
|
||||||
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
|
let plaintext: Vec<u8> = (0..80).collect();
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
let ciphertext = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
let decrypted = cipher.decrypt(&ciphertext).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(plaintext, decrypted);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -290,8 +423,8 @@ mod tests {
|
|||||||
let key = [0x12u8; 32];
|
let key = [0x12u8; 32];
|
||||||
let iv = [0x34u8; 16];
|
let iv = [0x34u8; 16];
|
||||||
|
|
||||||
let original = [0x56u8; 48]; // 3 blocks
|
let original = [0x56u8; 48];
|
||||||
let mut buffer = original.clone();
|
let mut buffer = original;
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
@@ -317,35 +450,93 @@ mod tests {
|
|||||||
fn test_aes_cbc_unaligned_error() {
|
fn test_aes_cbc_unaligned_error() {
|
||||||
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
|
||||||
|
|
||||||
// 15 bytes
|
|
||||||
let result = cipher.encrypt(&[0u8; 15]);
|
let result = cipher.encrypt(&[0u8; 15]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
// 17 bytes
|
|
||||||
let result = cipher.encrypt(&[0u8; 17]);
|
let result = cipher.encrypt(&[0u8; 17]);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_cbc_avalanche_effect() {
|
fn test_aes_cbc_avalanche_effect() {
|
||||||
// Cipherplane
|
|
||||||
|
|
||||||
let key = [0xAB; 32];
|
let key = [0xAB; 32];
|
||||||
let iv = [0xCD; 16];
|
let iv = [0xCD; 16];
|
||||||
|
|
||||||
let mut plaintext1 = [0u8; 32];
|
let plaintext1 = [0u8; 32];
|
||||||
let mut plaintext2 = [0u8; 32];
|
let mut plaintext2 = [0u8; 32];
|
||||||
plaintext2[0] = 0x01; // Один бит отличается
|
plaintext2[0] = 0x01;
|
||||||
|
|
||||||
let cipher = AesCbc::new(key, iv);
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
|
||||||
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
|
||||||
|
|
||||||
// First Blocks Diff
|
|
||||||
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
|
||||||
|
|
||||||
// Second Blocks Diff
|
|
||||||
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_iv_matters() {
|
||||||
|
let key = [0x55; 32];
|
||||||
|
let plaintext = [0x77u8; 16];
|
||||||
|
|
||||||
|
let cipher1 = AesCbc::new(key, [0u8; 16]);
|
||||||
|
let cipher2 = AesCbc::new(key, [1u8; 16]);
|
||||||
|
|
||||||
|
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
|
||||||
|
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
|
assert_ne!(ciphertext1, ciphertext2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_deterministic() {
|
||||||
|
let key = [0x99; 32];
|
||||||
|
let iv = [0x88; 16];
|
||||||
|
let plaintext = [0x77u8; 32];
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
|
||||||
|
let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(ciphertext1, ciphertext2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Zeroize Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_aes_cbc_zeroize_on_drop() {
|
||||||
|
let key = [0xAA; 32];
|
||||||
|
let iv = [0xBB; 16];
|
||||||
|
|
||||||
|
let cipher = AesCbc::new(key, iv);
|
||||||
|
// Verify key/iv are set
|
||||||
|
assert_eq!(cipher.key, [0xAA; 32]);
|
||||||
|
assert_eq!(cipher.iv, [0xBB; 16]);
|
||||||
|
|
||||||
|
drop(cipher);
|
||||||
|
// After drop, key/iv are zeroized (can't observe directly,
|
||||||
|
// but the Drop impl runs without panic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Error Handling Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_key_length() {
|
||||||
|
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_iv_length() {
|
||||||
|
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,3 +1,16 @@
|
|||||||
|
//! Cryptographic hash functions
|
||||||
|
//!
|
||||||
|
//! ## Protocol-required algorithms
|
||||||
|
//!
|
||||||
|
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
|
||||||
|
//! hash functions are **required by the Telegram Middle Proxy protocol**
|
||||||
|
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
|
||||||
|
//! compatibility. They are NOT used for any security-sensitive purpose
|
||||||
|
//! outside of that specific key derivation scheme mandated by Telegram.
|
||||||
|
//!
|
||||||
|
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
|
||||||
|
//! usages are intentional and protocol-mandated.
|
||||||
|
|
||||||
use hmac::{Hmac, Mac};
|
use hmac::{Hmac, Mac};
|
||||||
use sha2::Sha256;
|
use sha2::Sha256;
|
||||||
use md5::Md5;
|
use md5::Md5;
|
||||||
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
|
|||||||
mac.finalize().into_bytes().into()
|
mac.finalize().into_bytes().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SHA-1
|
/// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
|
||||||
|
/// Not used for general-purpose hashing.
|
||||||
pub fn sha1(data: &[u8]) -> [u8; 20] {
|
pub fn sha1(data: &[u8]) -> [u8; 20] {
|
||||||
let mut hasher = Sha1::new();
|
let mut hasher = Sha1::new();
|
||||||
hasher.update(data);
|
hasher.update(data);
|
||||||
hasher.finalize().into()
|
hasher.finalize().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MD5
|
/// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
|
||||||
|
/// Not used for general-purpose hashing.
|
||||||
pub fn md5(data: &[u8]) -> [u8; 16] {
|
pub fn md5(data: &[u8]) -> [u8; 16] {
|
||||||
let mut hasher = Md5::new();
|
let mut hasher = Md5::new();
|
||||||
hasher.update(data);
|
hasher.update(data);
|
||||||
@@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
|
|||||||
crc32fast::hash(data)
|
crc32fast::hash(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Middle Proxy Keygen
|
/// Middle Proxy key derivation
|
||||||
|
///
|
||||||
|
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
|
||||||
|
/// These algorithms are NOT replaceable here — changing them would break
|
||||||
|
/// interoperability with Telegram's middle proxy infrastructure.
|
||||||
pub fn derive_middleproxy_keys(
|
pub fn derive_middleproxy_keys(
|
||||||
nonce_srv: &[u8; 16],
|
nonce_srv: &[u8; 16],
|
||||||
nonce_clt: &[u8; 16],
|
nonce_clt: &[u8; 16],
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ pub mod random;
|
|||||||
|
|
||||||
pub use aes::{AesCtr, AesCbc};
|
pub use aes::{AesCtr, AesCbc};
|
||||||
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
|
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
|
||||||
pub use random::{SecureRandom, SECURE_RANDOM};
|
pub use random::SecureRandom;
|
||||||
@@ -3,11 +3,8 @@
|
|||||||
use rand::{Rng, RngCore, SeedableRng};
|
use rand::{Rng, RngCore, SeedableRng};
|
||||||
use rand::rngs::StdRng;
|
use rand::rngs::StdRng;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use zeroize::Zeroize;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
|
|
||||||
/// Global secure random instance
|
|
||||||
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
|
|
||||||
|
|
||||||
/// Cryptographically secure PRNG with AES-CTR
|
/// Cryptographically secure PRNG with AES-CTR
|
||||||
pub struct SecureRandom {
|
pub struct SecureRandom {
|
||||||
@@ -20,18 +17,30 @@ struct SecureRandomInner {
|
|||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for SecureRandomInner {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.buffer.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SecureRandom {
|
impl SecureRandom {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let mut rng = StdRng::from_entropy();
|
let mut seed_source = rand::rng();
|
||||||
|
let mut rng = StdRng::from_rng(&mut seed_source);
|
||||||
|
|
||||||
let mut key = [0u8; 32];
|
let mut key = [0u8; 32];
|
||||||
rng.fill_bytes(&mut key);
|
rng.fill_bytes(&mut key);
|
||||||
let iv: u128 = rng.gen();
|
let iv: u128 = rng.random();
|
||||||
|
|
||||||
|
let cipher = AesCtr::new(&key, iv);
|
||||||
|
|
||||||
|
// Zeroize local key copy — cipher already consumed it
|
||||||
|
key.zeroize();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
inner: Mutex::new(SecureRandomInner {
|
inner: Mutex::new(SecureRandomInner {
|
||||||
rng,
|
rng,
|
||||||
cipher: AesCtr::new(&key, iv),
|
cipher,
|
||||||
buffer: Vec::with_capacity(1024),
|
buffer: Vec::with_capacity(1024),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@@ -78,7 +87,6 @@ impl SecureRandom {
|
|||||||
result |= (b as u64) << (i * 8);
|
result |= (b as u64) << (i * 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mask extra bits
|
|
||||||
if k < 64 {
|
if k < 64 {
|
||||||
result &= (1u64 << k) - 1;
|
result &= (1u64 << k) - 1;
|
||||||
}
|
}
|
||||||
@@ -107,13 +115,13 @@ impl SecureRandom {
|
|||||||
/// Generate random u32
|
/// Generate random u32
|
||||||
pub fn u32(&self) -> u32 {
|
pub fn u32(&self) -> u32 {
|
||||||
let mut inner = self.inner.lock();
|
let mut inner = self.inner.lock();
|
||||||
inner.rng.gen()
|
inner.rng.random()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate random u64
|
/// Generate random u64
|
||||||
pub fn u64(&self) -> u64 {
|
pub fn u64(&self) -> u64 {
|
||||||
let mut inner = self.inner.lock();
|
let mut inner = self.inner.lock();
|
||||||
inner.rng.gen()
|
inner.rng.random()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,12 +170,10 @@ mod tests {
|
|||||||
fn test_bits() {
|
fn test_bits() {
|
||||||
let rng = SecureRandom::new();
|
let rng = SecureRandom::new();
|
||||||
|
|
||||||
// Single bit should be 0 or 1
|
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
assert!(rng.bits(1) <= 1);
|
assert!(rng.bits(1) <= 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8 bits should be 0-255
|
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
assert!(rng.bits(8) <= 255);
|
assert!(rng.bits(8) <= 255);
|
||||||
}
|
}
|
||||||
@@ -185,10 +191,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have seen all items
|
|
||||||
assert_eq!(seen.len(), 5);
|
assert_eq!(seen.len(), 5);
|
||||||
|
|
||||||
// Empty slice should return None
|
|
||||||
let empty: Vec<i32> = vec![];
|
let empty: Vec<i32> = vec![];
|
||||||
assert!(rng.choose(&empty).is_none());
|
assert!(rng.choose(&empty).is_none());
|
||||||
}
|
}
|
||||||
@@ -201,12 +205,10 @@ mod tests {
|
|||||||
let mut shuffled = original.clone();
|
let mut shuffled = original.clone();
|
||||||
rng.shuffle(&mut shuffled);
|
rng.shuffle(&mut shuffled);
|
||||||
|
|
||||||
// Should contain same elements
|
|
||||||
let mut sorted = shuffled.clone();
|
let mut sorted = shuffled.clone();
|
||||||
sorted.sort();
|
sorted.sort();
|
||||||
assert_eq!(sorted, original);
|
assert_eq!(sorted, original);
|
||||||
|
|
||||||
// Should be different order (with very high probability)
|
|
||||||
assert_ne!(shuffled, original);
|
assert_ne!(shuffled, original);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
282
src/error.rs
282
src/error.rs
@@ -1,8 +1,170 @@
|
|||||||
//! Error Types
|
//! Error Types
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
// ============= Stream Errors =============
|
||||||
|
|
||||||
|
/// Errors specific to stream I/O operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StreamError {
|
||||||
|
/// Partial read: got fewer bytes than expected
|
||||||
|
PartialRead {
|
||||||
|
expected: usize,
|
||||||
|
got: usize,
|
||||||
|
},
|
||||||
|
/// Partial write: wrote fewer bytes than expected
|
||||||
|
PartialWrite {
|
||||||
|
expected: usize,
|
||||||
|
written: usize,
|
||||||
|
},
|
||||||
|
/// Stream is in poisoned state and cannot be used
|
||||||
|
Poisoned {
|
||||||
|
reason: String,
|
||||||
|
},
|
||||||
|
/// Buffer overflow: attempted to buffer more than allowed
|
||||||
|
BufferOverflow {
|
||||||
|
limit: usize,
|
||||||
|
attempted: usize,
|
||||||
|
},
|
||||||
|
/// Invalid frame format
|
||||||
|
InvalidFrame {
|
||||||
|
details: String,
|
||||||
|
},
|
||||||
|
/// Unexpected end of stream
|
||||||
|
UnexpectedEof,
|
||||||
|
/// Underlying I/O error
|
||||||
|
Io(std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for StreamError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::PartialRead { expected, got } => {
|
||||||
|
write!(f, "partial read: expected {} bytes, got {}", expected, got)
|
||||||
|
}
|
||||||
|
Self::PartialWrite { expected, written } => {
|
||||||
|
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
|
||||||
|
}
|
||||||
|
Self::Poisoned { reason } => {
|
||||||
|
write!(f, "stream poisoned: {}", reason)
|
||||||
|
}
|
||||||
|
Self::BufferOverflow { limit, attempted } => {
|
||||||
|
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
|
||||||
|
}
|
||||||
|
Self::InvalidFrame { details } => {
|
||||||
|
write!(f, "invalid frame: {}", details)
|
||||||
|
}
|
||||||
|
Self::UnexpectedEof => {
|
||||||
|
write!(f, "unexpected end of stream")
|
||||||
|
}
|
||||||
|
Self::Io(e) => {
|
||||||
|
write!(f, "I/O error: {}", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for StreamError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Io(e) => Some(e),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for StreamError {
|
||||||
|
fn from(err: std::io::Error) -> Self {
|
||||||
|
Self::Io(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StreamError> for std::io::Error {
|
||||||
|
fn from(err: StreamError) -> Self {
|
||||||
|
match err {
|
||||||
|
StreamError::Io(e) => e,
|
||||||
|
StreamError::UnexpectedEof => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
|
||||||
|
}
|
||||||
|
StreamError::Poisoned { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||||
|
}
|
||||||
|
StreamError::BufferOverflow { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
|
||||||
|
}
|
||||||
|
StreamError::InvalidFrame { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
|
||||||
|
}
|
||||||
|
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Recoverable Trait =============
|
||||||
|
|
||||||
|
/// Trait for errors that may be recoverable
|
||||||
|
pub trait Recoverable {
|
||||||
|
/// Check if error is recoverable (can retry operation)
|
||||||
|
fn is_recoverable(&self) -> bool;
|
||||||
|
|
||||||
|
/// Check if connection can continue after this error
|
||||||
|
fn can_continue(&self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Recoverable for StreamError {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
||||||
|
Self::Io(e) => matches!(
|
||||||
|
e.kind(),
|
||||||
|
std::io::ErrorKind::WouldBlock
|
||||||
|
| std::io::ErrorKind::Interrupted
|
||||||
|
| std::io::ErrorKind::TimedOut
|
||||||
|
),
|
||||||
|
Self::Poisoned { .. }
|
||||||
|
| Self::BufferOverflow { .. }
|
||||||
|
| Self::InvalidFrame { .. }
|
||||||
|
| Self::UnexpectedEof => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Poisoned { .. } => false,
|
||||||
|
Self::UnexpectedEof => false,
|
||||||
|
Self::BufferOverflow { .. } => false,
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Recoverable for std::io::Error {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self.kind(),
|
||||||
|
std::io::ErrorKind::WouldBlock
|
||||||
|
| std::io::ErrorKind::Interrupted
|
||||||
|
| std::io::ErrorKind::TimedOut
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
!matches!(
|
||||||
|
self.kind(),
|
||||||
|
std::io::ErrorKind::BrokenPipe
|
||||||
|
| std::io::ErrorKind::ConnectionReset
|
||||||
|
| std::io::ErrorKind::ConnectionAborted
|
||||||
|
| std::io::ErrorKind::NotConnected
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Main Proxy Errors =============
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ProxyError {
|
pub enum ProxyError {
|
||||||
// ============= Crypto Errors =============
|
// ============= Crypto Errors =============
|
||||||
@@ -13,6 +175,11 @@ pub enum ProxyError {
|
|||||||
#[error("Invalid key length: expected {expected}, got {got}")]
|
#[error("Invalid key length: expected {expected}, got {got}")]
|
||||||
InvalidKeyLength { expected: usize, got: usize },
|
InvalidKeyLength { expected: usize, got: usize },
|
||||||
|
|
||||||
|
// ============= Stream Errors =============
|
||||||
|
|
||||||
|
#[error("Stream error: {0}")]
|
||||||
|
Stream(#[from] StreamError),
|
||||||
|
|
||||||
// ============= Protocol Errors =============
|
// ============= Protocol Errors =============
|
||||||
|
|
||||||
#[error("Invalid handshake: {0}")]
|
#[error("Invalid handshake: {0}")]
|
||||||
@@ -39,6 +206,12 @@ pub enum ProxyError {
|
|||||||
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
||||||
SeqNoMismatch { expected: i32, got: i32 },
|
SeqNoMismatch { expected: i32, got: i32 },
|
||||||
|
|
||||||
|
#[error("TLS handshake failed: {reason}")]
|
||||||
|
TlsHandshakeFailed { reason: String },
|
||||||
|
|
||||||
|
#[error("Telegram handshake timeout")]
|
||||||
|
TgHandshakeTimeout,
|
||||||
|
|
||||||
// ============= Network Errors =============
|
// ============= Network Errors =============
|
||||||
|
|
||||||
#[error("Connection timeout to {addr}")]
|
#[error("Connection timeout to {addr}")]
|
||||||
@@ -55,6 +228,9 @@ pub enum ProxyError {
|
|||||||
#[error("Invalid proxy protocol header")]
|
#[error("Invalid proxy protocol header")]
|
||||||
InvalidProxyProtocol,
|
InvalidProxyProtocol,
|
||||||
|
|
||||||
|
#[error("Proxy error: {0}")]
|
||||||
|
Proxy(String),
|
||||||
|
|
||||||
// ============= Config Errors =============
|
// ============= Config Errors =============
|
||||||
|
|
||||||
#[error("Config error: {0}")]
|
#[error("Config error: {0}")]
|
||||||
@@ -77,27 +253,53 @@ pub enum ProxyError {
|
|||||||
#[error("Unknown user")]
|
#[error("Unknown user")]
|
||||||
UnknownUser,
|
UnknownUser,
|
||||||
|
|
||||||
|
#[error("Rate limited")]
|
||||||
|
RateLimited,
|
||||||
|
|
||||||
// ============= General Errors =============
|
// ============= General Errors =============
|
||||||
|
|
||||||
#[error("Internal error: {0}")]
|
#[error("Internal error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Recoverable for ProxyError {
|
||||||
|
fn is_recoverable(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Stream(e) => e.is_recoverable(),
|
||||||
|
Self::Io(e) => e.is_recoverable(),
|
||||||
|
Self::ConnectionTimeout { .. } => true,
|
||||||
|
Self::RateLimited => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_continue(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Stream(e) => e.can_continue(),
|
||||||
|
Self::Io(e) => e.can_continue(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convenient Result type alias
|
/// Convenient Result type alias
|
||||||
pub type Result<T> = std::result::Result<T, ProxyError>;
|
pub type Result<T> = std::result::Result<T, ProxyError>;
|
||||||
|
|
||||||
|
/// Result type for stream operations
|
||||||
|
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
||||||
|
|
||||||
/// Result with optional bad client handling
|
/// Result with optional bad client handling
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum HandshakeResult<T> {
|
pub enum HandshakeResult<T, R, W> {
|
||||||
/// Handshake succeeded
|
/// Handshake succeeded
|
||||||
Success(T),
|
Success(T),
|
||||||
/// Client failed validation, needs masking
|
/// Client failed validation, needs masking. Returns ownership of streams.
|
||||||
BadClient,
|
BadClient { reader: R, writer: W },
|
||||||
/// Error occurred
|
/// Error occurred
|
||||||
Error(ProxyError),
|
Error(ProxyError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> HandshakeResult<T> {
|
impl<T, R, W> HandshakeResult<T, R, W> {
|
||||||
/// Check if successful
|
/// Check if successful
|
||||||
pub fn is_success(&self) -> bool {
|
pub fn is_success(&self) -> bool {
|
||||||
matches!(self, HandshakeResult::Success(_))
|
matches!(self, HandshakeResult::Success(_))
|
||||||
@@ -105,58 +307,87 @@ impl<T> HandshakeResult<T> {
|
|||||||
|
|
||||||
/// Check if bad client
|
/// Check if bad client
|
||||||
pub fn is_bad_client(&self) -> bool {
|
pub fn is_bad_client(&self) -> bool {
|
||||||
matches!(self, HandshakeResult::BadClient)
|
matches!(self, HandshakeResult::BadClient { .. })
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to Result, treating BadClient as error
|
|
||||||
pub fn into_result(self) -> Result<T> {
|
|
||||||
match self {
|
|
||||||
HandshakeResult::Success(v) => Ok(v),
|
|
||||||
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
|
|
||||||
HandshakeResult::Error(e) => Err(e),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Map the success value
|
/// Map the success value
|
||||||
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> {
|
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
|
||||||
match self {
|
match self {
|
||||||
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
|
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
|
||||||
HandshakeResult::BadClient => HandshakeResult::BadClient,
|
HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
|
||||||
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<ProxyError> for HandshakeResult<T> {
|
impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
|
||||||
fn from(err: ProxyError) -> Self {
|
fn from(err: ProxyError) -> Self {
|
||||||
HandshakeResult::Error(err)
|
HandshakeResult::Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<std::io::Error> for HandshakeResult<T> {
|
impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
|
||||||
fn from(err: std::io::Error) -> Self {
|
fn from(err: std::io::Error) -> Self {
|
||||||
HandshakeResult::Error(ProxyError::Io(err))
|
HandshakeResult::Error(ProxyError::Io(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
|
||||||
|
fn from(err: StreamError) -> Self {
|
||||||
|
HandshakeResult::Error(ProxyError::Stream(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_display() {
|
||||||
|
let err = StreamError::PartialRead { expected: 100, got: 50 };
|
||||||
|
assert!(err.to_string().contains("100"));
|
||||||
|
assert!(err.to_string().contains("50"));
|
||||||
|
|
||||||
|
let err = StreamError::Poisoned { reason: "test".into() };
|
||||||
|
assert!(err.to_string().contains("test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_recoverable() {
|
||||||
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
|
||||||
|
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
|
||||||
|
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
|
||||||
|
assert!(!StreamError::UnexpectedEof.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_can_continue() {
|
||||||
|
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
|
||||||
|
assert!(!StreamError::UnexpectedEof.can_continue());
|
||||||
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_error_to_io_error() {
|
||||||
|
let stream_err = StreamError::UnexpectedEof;
|
||||||
|
let io_err: std::io::Error = stream_err.into();
|
||||||
|
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_handshake_result() {
|
fn test_handshake_result() {
|
||||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
||||||
assert!(success.is_success());
|
assert!(success.is_success());
|
||||||
assert!(!success.is_bad_client());
|
assert!(!success.is_bad_client());
|
||||||
|
|
||||||
let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
|
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
|
||||||
assert!(!bad.is_success());
|
assert!(!bad.is_success());
|
||||||
assert!(bad.is_bad_client());
|
assert!(bad.is_bad_client());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_handshake_result_map() {
|
fn test_handshake_result_map() {
|
||||||
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
||||||
let mapped = success.map(|x| x * 2);
|
let mapped = success.map(|x| x * 2);
|
||||||
|
|
||||||
match mapped {
|
match mapped {
|
||||||
@@ -165,6 +396,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_proxy_error_recoverable() {
|
||||||
|
let err = ProxyError::RateLimited;
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
|
||||||
|
let err = ProxyError::InvalidHandshake("bad".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_error_display() {
|
fn test_error_display() {
|
||||||
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
||||||
|
|||||||
406
src/main.rs
406
src/main.rs
@@ -1,158 +1,306 @@
|
|||||||
//! Telemt - MTProxy on Rust
|
//! Telemt - MTProxy on Rust
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tracing::{info, error, Level};
|
use tokio::sync::Semaphore;
|
||||||
use tracing_subscriber::{FmtSubscriber, EnvFilter};
|
use tracing::{info, error, warn, debug};
|
||||||
|
use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*};
|
||||||
|
|
||||||
mod error;
|
mod cli;
|
||||||
|
mod config;
|
||||||
mod crypto;
|
mod crypto;
|
||||||
|
mod error;
|
||||||
mod protocol;
|
mod protocol;
|
||||||
|
mod proxy;
|
||||||
|
mod stats;
|
||||||
mod stream;
|
mod stream;
|
||||||
mod transport;
|
mod transport;
|
||||||
mod proxy;
|
|
||||||
mod config;
|
|
||||||
mod stats;
|
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
use config::ProxyConfig;
|
use crate::config::{ProxyConfig, LogLevel};
|
||||||
use stats::{Stats, ReplayChecker};
|
use crate::proxy::ClientHandler;
|
||||||
use transport::ConnectionPool;
|
use crate::stats::{Stats, ReplayChecker};
|
||||||
use proxy::ClientHandler;
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||||
|
use crate::util::ip::detect_ip;
|
||||||
|
use crate::stream::BufferPool;
|
||||||
|
|
||||||
#[tokio::main]
|
fn parse_cli() -> (String, bool, Option<String>) {
|
||||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
let mut config_path = "config.toml".to_string();
|
||||||
// Initialize logging with env filter
|
let mut silent = false;
|
||||||
// Use RUST_LOG=debug or RUST_LOG=trace for more details
|
let mut log_level: Option<String> = None;
|
||||||
let filter = EnvFilter::try_from_default_env()
|
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let subscriber = FmtSubscriber::builder()
|
let args: Vec<String> = std::env::args().skip(1).collect();
|
||||||
.with_env_filter(filter)
|
|
||||||
.with_target(true)
|
|
||||||
.with_thread_ids(false)
|
|
||||||
.with_file(false)
|
|
||||||
.with_line_number(false)
|
|
||||||
.finish();
|
|
||||||
|
|
||||||
tracing::subscriber::set_global_default(subscriber)?;
|
// Check for --init first (handled before tokio)
|
||||||
|
if let Some(init_opts) = cli::parse_init_args(&args) {
|
||||||
// Load configuration
|
if let Err(e) = cli::run_init(init_opts) {
|
||||||
let config_path = std::env::args()
|
eprintln!("[telemt] Init failed: {}", e);
|
||||||
.nth(1)
|
|
||||||
.unwrap_or_else(|| "config.toml".to_string());
|
|
||||||
|
|
||||||
info!("Loading configuration from {}", config_path);
|
|
||||||
|
|
||||||
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
|
|
||||||
error!("Failed to load config: {}", e);
|
|
||||||
info!("Using default configuration");
|
|
||||||
ProxyConfig::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(e) = config.validate() {
|
|
||||||
error!("Invalid configuration: {}", e);
|
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
std::process::exit(0);
|
||||||
let config = Arc::new(config);
|
|
||||||
|
|
||||||
info!("Starting MTProto Proxy on port {}", config.port);
|
|
||||||
info!("Fast mode: {}", config.fast_mode);
|
|
||||||
info!("Modes: classic={}, secure={}, tls={}",
|
|
||||||
config.modes.classic, config.modes.secure, config.modes.tls);
|
|
||||||
|
|
||||||
// Initialize components
|
|
||||||
let stats = Arc::new(Stats::new());
|
|
||||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
|
||||||
let pool = Arc::new(ConnectionPool::new());
|
|
||||||
|
|
||||||
// Create handler
|
|
||||||
let handler = Arc::new(ClientHandler::new(
|
|
||||||
Arc::clone(&config),
|
|
||||||
Arc::clone(&stats),
|
|
||||||
Arc::clone(&replay_checker),
|
|
||||||
Arc::clone(&pool),
|
|
||||||
));
|
|
||||||
|
|
||||||
// Start listener
|
|
||||||
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
|
|
||||||
.parse()?;
|
|
||||||
|
|
||||||
let listener = TcpListener::bind(addr).await?;
|
|
||||||
info!("Listening on {}", addr);
|
|
||||||
|
|
||||||
// Print proxy links
|
|
||||||
print_proxy_links(&config);
|
|
||||||
|
|
||||||
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
|
|
||||||
|
|
||||||
// Main accept loop
|
|
||||||
let accept_loop = async {
|
|
||||||
loop {
|
|
||||||
match listener.accept().await {
|
|
||||||
Ok((stream, peer)) => {
|
|
||||||
let handler = Arc::clone(&handler);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
handler.handle(stream, peer).await;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--silent" | "-s" => { silent = true; }
|
||||||
|
"--log-level" => {
|
||||||
|
i += 1;
|
||||||
|
if i < args.len() { log_level = Some(args[i].clone()); }
|
||||||
|
}
|
||||||
|
s if s.starts_with("--log-level=") => {
|
||||||
|
log_level = Some(s.trim_start_matches("--log-level=").to_string());
|
||||||
|
}
|
||||||
|
"--help" | "-h" => {
|
||||||
|
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Options:");
|
||||||
|
eprintln!(" --silent, -s Suppress info logs");
|
||||||
|
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
|
||||||
|
eprintln!(" --help, -h Show this help");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Setup (fire-and-forget):");
|
||||||
|
eprintln!(" --init Generate config, install systemd service, start");
|
||||||
|
eprintln!(" --port <PORT> Listen port (default: 443)");
|
||||||
|
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
|
||||||
|
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
|
||||||
|
eprintln!(" --user <NAME> Username (default: user)");
|
||||||
|
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
|
||||||
|
eprintln!(" --no-start Don't start the service after install");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
s if !s.starts_with('-') => { config_path = s.to_string(); }
|
||||||
|
other => { eprintln!("Unknown option: {}", other); }
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
(config_path, silent, log_level)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let (config_path, cli_silent, cli_log_level) = parse_cli();
|
||||||
|
|
||||||
|
let config = match ProxyConfig::load(&config_path) {
|
||||||
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Accept error: {}", e);
|
if std::path::Path::new(&config_path).exists() {
|
||||||
}
|
eprintln!("[telemt] Error: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
} else {
|
||||||
|
let default = ProxyConfig::default();
|
||||||
|
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||||
|
eprintln!("[telemt] Created default config at {}", config_path);
|
||||||
|
default
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Graceful shutdown
|
if let Err(e) = config.validate() {
|
||||||
tokio::select! {
|
eprintln!("[telemt] Invalid config: {}", e);
|
||||||
_ = accept_loop => {}
|
std::process::exit(1);
|
||||||
_ = signal::ctrl_c() => {
|
}
|
||||||
info!("Shutting down...");
|
|
||||||
|
let has_rust_log = std::env::var("RUST_LOG").is_ok();
|
||||||
|
let effective_log_level = if cli_silent {
|
||||||
|
LogLevel::Silent
|
||||||
|
} else if let Some(ref s) = cli_log_level {
|
||||||
|
LogLevel::from_str_loose(s)
|
||||||
|
} else {
|
||||||
|
config.general.log_level.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start with INFO so startup messages are always visible,
|
||||||
|
// then switch to user-configured level after startup
|
||||||
|
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(filter_layer)
|
||||||
|
.with(fmt::Layer::default())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
|
||||||
|
info!("Log level: {}", effective_log_level);
|
||||||
|
info!("Modes: classic={} secure={} tls={}",
|
||||||
|
config.general.modes.classic,
|
||||||
|
config.general.modes.secure,
|
||||||
|
config.general.modes.tls);
|
||||||
|
info!("TLS domain: {}", config.censorship.tls_domain);
|
||||||
|
info!("Mask: {} -> {}:{}",
|
||||||
|
config.censorship.mask,
|
||||||
|
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
|
||||||
|
config.censorship.mask_port);
|
||||||
|
|
||||||
|
if config.censorship.tls_domain == "www.google.com" {
|
||||||
|
warn!("Using default tls_domain. Consider setting a custom domain.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let prefer_ipv6 = config.general.prefer_ipv6;
|
||||||
|
let config = Arc::new(config);
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let replay_checker = Arc::new(ReplayChecker::new(
|
||||||
|
config.access.replay_check_len,
|
||||||
|
Duration::from_secs(config.access.replay_window_secs),
|
||||||
|
));
|
||||||
|
|
||||||
|
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||||
|
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
|
||||||
|
|
||||||
|
// Connection concurrency limit — prevents OOM under SYN flood / connection storm.
|
||||||
|
// 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead).
|
||||||
|
// 10000 connections ≈ 640MB peak memory.
|
||||||
|
let max_connections = Arc::new(Semaphore::new(10_000));
|
||||||
|
|
||||||
|
// Startup DC ping
|
||||||
|
info!("=== Telegram DC Connectivity ===");
|
||||||
|
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
|
||||||
|
for upstream_result in &ping_results {
|
||||||
|
info!(" via {}", upstream_result.upstream_name);
|
||||||
|
for dc in &upstream_result.results {
|
||||||
|
match (&dc.rtt_ms, &dc.error) {
|
||||||
|
(Some(rtt), _) => {
|
||||||
|
info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
|
||||||
|
}
|
||||||
|
(None, Some(err)) => {
|
||||||
|
info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("================================");
|
||||||
|
|
||||||
|
// Background tasks
|
||||||
|
let um_clone = upstream_manager.clone();
|
||||||
|
tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
|
||||||
|
|
||||||
|
let rc_clone = replay_checker.clone();
|
||||||
|
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
|
||||||
|
|
||||||
|
let detected_ip = detect_ip().await;
|
||||||
|
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
|
||||||
|
|
||||||
|
let mut listeners = Vec::new();
|
||||||
|
|
||||||
|
for listener_conf in &config.server.listeners {
|
||||||
|
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
|
||||||
|
let options = ListenOptions {
|
||||||
|
ipv6_only: listener_conf.ip.is_ipv6(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
match create_listener(addr, &options) {
|
||||||
|
Ok(socket) => {
|
||||||
|
let listener = TcpListener::from_std(socket.into())?;
|
||||||
|
info!("Listening on {}", addr);
|
||||||
|
|
||||||
|
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
||||||
|
ip
|
||||||
|
} else if listener_conf.ip.is_unspecified() {
|
||||||
|
if listener_conf.ip.is_ipv4() {
|
||||||
|
detected_ip.ipv4.unwrap_or(listener_conf.ip)
|
||||||
|
} else {
|
||||||
|
detected_ip.ipv6.unwrap_or(listener_conf.ip)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
listener_conf.ip
|
||||||
|
};
|
||||||
|
|
||||||
|
if !config.show_link.is_empty() {
|
||||||
|
info!("--- Proxy Links ({}) ---", public_ip);
|
||||||
|
for user_name in &config.show_link {
|
||||||
|
if let Some(secret) = config.access.users.get(user_name) {
|
||||||
|
info!("User: {}", user_name);
|
||||||
|
if config.general.modes.classic {
|
||||||
|
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
|
||||||
|
public_ip, config.server.port, secret);
|
||||||
|
}
|
||||||
|
if config.general.modes.secure {
|
||||||
|
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
||||||
|
public_ip, config.server.port, secret);
|
||||||
|
}
|
||||||
|
if config.general.modes.tls {
|
||||||
|
let domain_hex = hex::encode(&config.censorship.tls_domain);
|
||||||
|
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||||
|
public_ip, config.server.port, secret, domain_hex);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("User '{}' in show_link not found", user_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("------------------------");
|
||||||
|
}
|
||||||
|
|
||||||
|
listeners.push(listener);
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to bind to {}: {}", addr, e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup
|
if listeners.is_empty() {
|
||||||
pool.close_all().await;
|
error!("No listeners. Exiting.");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to user-configured log level after startup
|
||||||
|
let runtime_filter = if has_rust_log {
|
||||||
|
EnvFilter::from_default_env()
|
||||||
|
} else {
|
||||||
|
EnvFilter::new(effective_log_level.to_filter_str())
|
||||||
|
};
|
||||||
|
filter_handle.reload(runtime_filter).expect("Failed to switch log filter");
|
||||||
|
|
||||||
|
for listener in listeners {
|
||||||
|
let config = config.clone();
|
||||||
|
let stats = stats.clone();
|
||||||
|
let upstream_manager = upstream_manager.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
let buffer_pool = buffer_pool.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
match listener.accept().await {
|
||||||
|
Ok((stream, peer_addr)) => {
|
||||||
|
let config = config.clone();
|
||||||
|
let stats = stats.clone();
|
||||||
|
let upstream_manager = upstream_manager.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
let buffer_pool = buffer_pool.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = ClientHandler::new(
|
||||||
|
stream, peer_addr, config, stats,
|
||||||
|
upstream_manager, replay_checker, buffer_pool, rng
|
||||||
|
).run().await {
|
||||||
|
debug!(peer = %peer_addr, error = %e, "Connection error");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Accept error: {}", e);
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match signal::ctrl_c().await {
|
||||||
|
Ok(()) => info!("Shutting down..."),
|
||||||
|
Err(e) => error!("Signal error: {}", e),
|
||||||
|
}
|
||||||
|
|
||||||
info!("Goodbye!");
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn print_proxy_links(config: &ProxyConfig) {
|
|
||||||
println!("\n=== Proxy Links ===\n");
|
|
||||||
|
|
||||||
for (user, secret) in &config.users {
|
|
||||||
if config.modes.tls {
|
|
||||||
let tls_secret = format!(
|
|
||||||
"ee{}{}",
|
|
||||||
secret,
|
|
||||||
hex::encode(config.tls_domain.as_bytes())
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"{} (TLS): tg://proxy?server=IP&port={}&secret={}",
|
|
||||||
user, config.port, tls_secret
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.modes.secure {
|
|
||||||
println!(
|
|
||||||
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}",
|
|
||||||
user, config.port, secret
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.modes.classic {
|
|
||||||
println!(
|
|
||||||
"{} (Classic): tg://proxy?server=IP&port={}&secret={}",
|
|
||||||
user, config.port, secret
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!();
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("===================\n");
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
//! Protocol constants and datacenter addresses
|
//! Protocol constants and datacenter addresses
|
||||||
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
use once_cell::sync::Lazy;
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
// ============= Telegram Datacenters =============
|
// ============= Telegram Datacenters =============
|
||||||
|
|
||||||
pub const TG_DATACENTER_PORT: u16 = 443;
|
pub const TG_DATACENTER_PORT: u16 = 443;
|
||||||
|
|
||||||
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
|
||||||
vec![
|
vec![
|
||||||
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
|
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
|
||||||
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
|
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
|
||||||
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
|
||||||
vec![
|
vec![
|
||||||
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
|
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
|
||||||
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
|
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
|
||||||
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
|
|||||||
|
|
||||||
// ============= Middle Proxies (for advertising) =============
|
// ============= Middle Proxies (for advertising) =============
|
||||||
|
|
||||||
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||||
Lazy::new(|| {
|
LazyLock::new(|| {
|
||||||
let mut m = std::collections::HashMap::new();
|
let mut m = std::collections::HashMap::new();
|
||||||
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||||
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
|
||||||
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
|
|||||||
m
|
m
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
|
||||||
Lazy::new(|| {
|
LazyLock::new(|| {
|
||||||
let mut m = std::collections::HashMap::new();
|
let mut m = std::collections::HashMap::new();
|
||||||
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||||
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
|
||||||
@@ -167,7 +167,8 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
|||||||
// ============= Buffer Sizes =============
|
// ============= Buffer Sizes =============
|
||||||
|
|
||||||
/// Default buffer size
|
/// Default buffer size
|
||||||
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
|
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
||||||
|
|
||||||
/// Small buffer size for bad client handling
|
/// Small buffer size for bad client handling
|
||||||
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
//! MTProto Obfuscation
|
//! MTProto Obfuscation
|
||||||
|
|
||||||
|
use zeroize::Zeroize;
|
||||||
use crate::crypto::{sha256, AesCtr};
|
use crate::crypto::{sha256, AesCtr};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use super::constants::*;
|
use super::constants::*;
|
||||||
|
|
||||||
/// Obfuscation parameters from handshake
|
/// Obfuscation parameters from handshake
|
||||||
|
///
|
||||||
|
/// Key material is zeroized on drop.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ObfuscationParams {
|
pub struct ObfuscationParams {
|
||||||
/// Key for decrypting client -> proxy traffic
|
/// Key for decrypting client -> proxy traffic
|
||||||
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
|
|||||||
pub dc_idx: i16,
|
pub dc_idx: i16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for ObfuscationParams {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.decrypt_key.zeroize();
|
||||||
|
self.decrypt_iv.zeroize();
|
||||||
|
self.encrypt_key.zeroize();
|
||||||
|
self.encrypt_iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ObfuscationParams {
|
impl ObfuscationParams {
|
||||||
/// Parse obfuscation parameters from handshake bytes
|
/// Parse obfuscation parameters from handshake bytes
|
||||||
/// Returns None if handshake doesn't match any user secret
|
/// Returns None if handshake doesn't match any user secret
|
||||||
pub fn from_handshake(
|
pub fn from_handshake(
|
||||||
handshake: &[u8; HANDSHAKE_LEN],
|
handshake: &[u8; HANDSHAKE_LEN],
|
||||||
secrets: &[(String, Vec<u8>)], // (username, secret_bytes)
|
secrets: &[(String, Vec<u8>)],
|
||||||
) -> Option<(Self, String)> {
|
) -> Option<(Self, String)> {
|
||||||
// Extract prekey and IV for decryption
|
|
||||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
// Reversed for encryption direction
|
|
||||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
for (username, secret) in secrets {
|
for (username, secret) in secrets {
|
||||||
// Derive decryption key
|
|
||||||
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||||
dec_key_input.extend_from_slice(dec_prekey);
|
dec_key_input.extend_from_slice(dec_prekey);
|
||||||
dec_key_input.extend_from_slice(secret);
|
dec_key_input.extend_from_slice(secret);
|
||||||
@@ -47,26 +56,22 @@ impl ObfuscationParams {
|
|||||||
|
|
||||||
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Create decryptor and decrypt handshake
|
|
||||||
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
|
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
|
||||||
let decrypted = decryptor.decrypt(handshake);
|
let decrypted = decryptor.decrypt(handshake);
|
||||||
|
|
||||||
// Check protocol tag
|
|
||||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||||
.try_into()
|
.try_into()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||||
Some(tag) => tag,
|
Some(tag) => tag,
|
||||||
None => continue, // Try next secret
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Extract DC index
|
|
||||||
let dc_idx = i16::from_le_bytes(
|
let dc_idx = i16::from_le_bytes(
|
||||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Derive encryption key
|
|
||||||
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||||
enc_key_input.extend_from_slice(enc_prekey);
|
enc_key_input.extend_from_slice(enc_prekey);
|
||||||
enc_key_input.extend_from_slice(secret);
|
enc_key_input.extend_from_slice(secret);
|
||||||
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
|
|||||||
|
|
||||||
/// Check if nonce is valid (not matching reserved patterns)
|
/// Check if nonce is valid (not matching reserved patterns)
|
||||||
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
||||||
// Check first byte
|
|
||||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check first 4 bytes
|
|
||||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
||||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check bytes 4-7
|
|
||||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
||||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
||||||
return false;
|
return false;
|
||||||
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
|
|||||||
pub fn prepare_tg_nonce(
|
pub fn prepare_tg_nonce(
|
||||||
nonce: &mut [u8; HANDSHAKE_LEN],
|
nonce: &mut [u8; HANDSHAKE_LEN],
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
enc_key_iv: Option<&[u8]>, // For fast mode
|
enc_key_iv: Option<&[u8]>,
|
||||||
) {
|
) {
|
||||||
// Set protocol tag
|
|
||||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||||
|
|
||||||
// For fast mode, copy the reversed enc_key_iv
|
|
||||||
if let Some(key_iv) = enc_key_iv {
|
if let Some(key_iv) = enc_key_iv {
|
||||||
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
|
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
|
||||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
|
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
|
||||||
@@ -161,14 +161,12 @@ pub fn prepare_tg_nonce(
|
|||||||
|
|
||||||
/// Encrypt the outgoing nonce for Telegram
|
/// Encrypt the outgoing nonce for Telegram
|
||||||
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||||
// Derive encryption key from the nonce itself
|
|
||||||
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let enc_key = sha256(key_iv);
|
let enc_key = sha256(key_iv);
|
||||||
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
|
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
|
||||||
|
|
||||||
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
|
||||||
// Only encrypt from PROTO_TAG_POS onwards
|
|
||||||
let mut result = nonce.to_vec();
|
let mut result = nonce.to_vec();
|
||||||
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
|
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
|
||||||
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
|
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
|
||||||
@@ -182,22 +180,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_is_valid_nonce() {
|
fn test_is_valid_nonce() {
|
||||||
// Valid nonce
|
|
||||||
let mut valid = [0x42u8; HANDSHAKE_LEN];
|
let mut valid = [0x42u8; HANDSHAKE_LEN];
|
||||||
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
|
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
|
||||||
assert!(is_valid_nonce(&valid));
|
assert!(is_valid_nonce(&valid));
|
||||||
|
|
||||||
// Invalid: starts with 0xef
|
|
||||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||||
invalid[0] = 0xef;
|
invalid[0] = 0xef;
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|
||||||
// Invalid: starts with HEAD
|
|
||||||
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
let mut invalid = [0x00u8; HANDSHAKE_LEN];
|
||||||
invalid[..4].copy_from_slice(b"HEAD");
|
invalid[..4].copy_from_slice(b"HEAD");
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|
||||||
// Invalid: bytes 4-7 are zeros
|
|
||||||
let mut invalid = [0x42u8; HANDSHAKE_LEN];
|
let mut invalid = [0x42u8; HANDSHAKE_LEN];
|
||||||
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
|
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
|
||||||
assert!(!is_valid_nonce(&invalid));
|
assert!(!is_valid_nonce(&invalid));
|
||||||
|
|||||||
@@ -1,14 +1,22 @@
|
|||||||
//! Fake TLS 1.3 Handshake
|
//! Fake TLS 1.3 Handshake
|
||||||
|
//!
|
||||||
|
//! This module handles the fake TLS 1.3 handshake used by MTProto proxy
|
||||||
|
//! for domain fronting. The handshake looks like valid TLS 1.3 but
|
||||||
|
//! actually carries MTProto authentication data.
|
||||||
|
|
||||||
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
|
use crate::crypto::{sha256_hmac, SecureRandom};
|
||||||
use crate::error::{ProxyError, Result};
|
use crate::error::{ProxyError, Result};
|
||||||
use super::constants::*;
|
use super::constants::*;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
// ============= Public Constants =============
|
||||||
|
|
||||||
/// TLS handshake digest length
|
/// TLS handshake digest length
|
||||||
pub const TLS_DIGEST_LEN: usize = 32;
|
pub const TLS_DIGEST_LEN: usize = 32;
|
||||||
|
|
||||||
/// Position of digest in TLS ClientHello
|
/// Position of digest in TLS ClientHello
|
||||||
pub const TLS_DIGEST_POS: usize = 11;
|
pub const TLS_DIGEST_POS: usize = 11;
|
||||||
|
|
||||||
/// Length to store for replay protection (first 16 bytes of digest)
|
/// Length to store for replay protection (first 16 bytes of digest)
|
||||||
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||||
|
|
||||||
@@ -16,6 +24,26 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
|||||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||||
|
|
||||||
|
// ============= Private Constants =============
|
||||||
|
|
||||||
|
/// TLS Extension types
|
||||||
|
mod extension_type {
|
||||||
|
pub const KEY_SHARE: u16 = 0x0033;
|
||||||
|
pub const SUPPORTED_VERSIONS: u16 = 0x002b;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TLS Cipher Suites
|
||||||
|
mod cipher_suite {
|
||||||
|
pub const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TLS Named Curves
|
||||||
|
mod named_curve {
|
||||||
|
pub const X25519: u16 = 0x001d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= TLS Validation Result =============
|
||||||
|
|
||||||
/// Result of validating TLS handshake
|
/// Result of validating TLS handshake
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TlsValidation {
|
pub struct TlsValidation {
|
||||||
@@ -29,7 +57,185 @@ pub struct TlsValidation {
|
|||||||
pub timestamp: u32,
|
pub timestamp: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= TLS Extension Builder =============
|
||||||
|
|
||||||
|
/// Builder for TLS extensions with correct length calculation
|
||||||
|
struct TlsExtensionBuilder {
|
||||||
|
extensions: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TlsExtensionBuilder {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
extensions: Vec::with_capacity(128),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add Key Share extension with X25519 key
|
||||||
|
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
|
||||||
|
// Extension type: key_share (0x0033)
|
||||||
|
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
|
||||||
|
|
||||||
|
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
|
||||||
|
// Extension data length
|
||||||
|
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
|
||||||
|
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
|
||||||
|
|
||||||
|
// Named curve: x25519
|
||||||
|
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
|
||||||
|
|
||||||
|
// Key length
|
||||||
|
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
|
||||||
|
|
||||||
|
// Key data
|
||||||
|
self.extensions.extend_from_slice(public_key);
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add Supported Versions extension
|
||||||
|
fn add_supported_versions(&mut self, version: u16) -> &mut Self {
|
||||||
|
// Extension type: supported_versions (0x002b)
|
||||||
|
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
|
||||||
|
|
||||||
|
// Extension data: length (2) + version (2)
|
||||||
|
self.extensions.extend_from_slice(&(2u16).to_be_bytes());
|
||||||
|
|
||||||
|
// Selected version
|
||||||
|
self.extensions.extend_from_slice(&version.to_be_bytes());
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build final extensions with length prefix
|
||||||
|
fn build(self) -> Vec<u8> {
|
||||||
|
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
||||||
|
|
||||||
|
// Extensions length (2 bytes)
|
||||||
|
let len = self.extensions.len() as u16;
|
||||||
|
result.extend_from_slice(&len.to_be_bytes());
|
||||||
|
|
||||||
|
// Extensions data
|
||||||
|
result.extend_from_slice(&self.extensions);
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current extensions without length prefix (for calculation)
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn as_bytes(&self) -> &[u8] {
|
||||||
|
&self.extensions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= ServerHello Builder =============
|
||||||
|
|
||||||
|
/// Builder for TLS ServerHello with correct structure
|
||||||
|
struct ServerHelloBuilder {
|
||||||
|
/// Random bytes (32 bytes, will contain digest)
|
||||||
|
random: [u8; 32],
|
||||||
|
/// Session ID (echoed from ClientHello)
|
||||||
|
session_id: Vec<u8>,
|
||||||
|
/// Cipher suite
|
||||||
|
cipher_suite: [u8; 2],
|
||||||
|
/// Compression method
|
||||||
|
compression: u8,
|
||||||
|
/// Extensions
|
||||||
|
extensions: TlsExtensionBuilder,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerHelloBuilder {
|
||||||
|
fn new(session_id: Vec<u8>) -> Self {
|
||||||
|
Self {
|
||||||
|
random: [0u8; 32],
|
||||||
|
session_id,
|
||||||
|
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
|
||||||
|
compression: 0x00,
|
||||||
|
extensions: TlsExtensionBuilder::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
|
||||||
|
self.extensions.add_key_share(key);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_tls13_version(mut self) -> Self {
|
||||||
|
// TLS 1.3 = 0x0304
|
||||||
|
self.extensions.add_supported_versions(0x0304);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build ServerHello message (without record header)
|
||||||
|
fn build_message(&self) -> Vec<u8> {
|
||||||
|
let extensions = self.extensions.extensions.clone();
|
||||||
|
let extensions_len = extensions.len() as u16;
|
||||||
|
|
||||||
|
// Calculate total length
|
||||||
|
let body_len = 2 + // version
|
||||||
|
32 + // random
|
||||||
|
1 + self.session_id.len() + // session_id length + data
|
||||||
|
2 + // cipher suite
|
||||||
|
1 + // compression
|
||||||
|
2 + extensions.len(); // extensions length + data
|
||||||
|
|
||||||
|
let mut message = Vec::with_capacity(4 + body_len);
|
||||||
|
|
||||||
|
// Handshake header
|
||||||
|
message.push(0x02); // ServerHello message type
|
||||||
|
|
||||||
|
// 3-byte length
|
||||||
|
let len_bytes = (body_len as u32).to_be_bytes();
|
||||||
|
message.extend_from_slice(&len_bytes[1..4]);
|
||||||
|
|
||||||
|
// Server version (TLS 1.2 in header, actual version in extension)
|
||||||
|
message.extend_from_slice(&TLS_VERSION);
|
||||||
|
|
||||||
|
// Random (32 bytes) - placeholder, will be replaced with digest
|
||||||
|
message.extend_from_slice(&self.random);
|
||||||
|
|
||||||
|
// Session ID
|
||||||
|
message.push(self.session_id.len() as u8);
|
||||||
|
message.extend_from_slice(&self.session_id);
|
||||||
|
|
||||||
|
// Cipher suite
|
||||||
|
message.extend_from_slice(&self.cipher_suite);
|
||||||
|
|
||||||
|
// Compression method
|
||||||
|
message.push(self.compression);
|
||||||
|
|
||||||
|
// Extensions length
|
||||||
|
message.extend_from_slice(&extensions_len.to_be_bytes());
|
||||||
|
|
||||||
|
// Extensions data
|
||||||
|
message.extend_from_slice(&extensions);
|
||||||
|
|
||||||
|
message
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build complete ServerHello TLS record
|
||||||
|
fn build_record(&self) -> Vec<u8> {
|
||||||
|
let message = self.build_message();
|
||||||
|
|
||||||
|
let mut record = Vec::with_capacity(5 + message.len());
|
||||||
|
|
||||||
|
// TLS record header
|
||||||
|
record.push(TLS_RECORD_HANDSHAKE);
|
||||||
|
record.extend_from_slice(&TLS_VERSION);
|
||||||
|
record.extend_from_slice(&(message.len() as u16).to_be_bytes());
|
||||||
|
|
||||||
|
// Message
|
||||||
|
record.extend_from_slice(&message);
|
||||||
|
|
||||||
|
record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Public Functions =============
|
||||||
|
|
||||||
/// Validate TLS ClientHello against user secrets
|
/// Validate TLS ClientHello against user secrets
|
||||||
|
///
|
||||||
|
/// Returns validation result if a matching user is found.
|
||||||
pub fn validate_tls_handshake(
|
pub fn validate_tls_handshake(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
secrets: &[(String, Vec<u8>)],
|
secrets: &[(String, Vec<u8>)],
|
||||||
@@ -86,7 +292,8 @@ pub fn validate_tls_handshake(
|
|||||||
// Check time skew
|
// Check time skew
|
||||||
if !ignore_time_skew {
|
if !ignore_time_skew {
|
||||||
// Allow very small timestamps (boot time instead of unix time)
|
// Allow very small timestamps (boot time instead of unix time)
|
||||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000;
|
// This is a quirk in some clients that use uptime instead of real time
|
||||||
|
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
|
||||||
|
|
||||||
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
|
if !is_boot_time && (time_diff < TIME_SKEW_MIN || time_diff > TIME_SKEW_MAX) {
|
||||||
continue;
|
continue;
|
||||||
@@ -105,79 +312,73 @@ pub fn validate_tls_handshake(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a fake X25519 public key for TLS
|
/// Generate a fake X25519 public key for TLS
|
||||||
/// This generates a value that looks like a valid X25519 key
|
///
|
||||||
pub fn gen_fake_x25519_key() -> [u8; 32] {
|
/// This generates random bytes that look like a valid X25519 public key.
|
||||||
// For simplicity, just generate random 32 bytes
|
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
|
||||||
// In real X25519, this would be a point on the curve
|
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
||||||
let bytes = SECURE_RANDOM.bytes(32);
|
let bytes = rng.bytes(32);
|
||||||
bytes.try_into().unwrap()
|
bytes.try_into().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build TLS ServerHello response
|
/// Build TLS ServerHello response
|
||||||
|
///
|
||||||
|
/// This builds a complete TLS 1.3-like response including:
|
||||||
|
/// - ServerHello record with extensions
|
||||||
|
/// - Change Cipher Spec record
|
||||||
|
/// - Fake encrypted certificate (Application Data record)
|
||||||
|
///
|
||||||
|
/// The response includes an HMAC digest that the client can verify.
|
||||||
pub fn build_server_hello(
|
pub fn build_server_hello(
|
||||||
secret: &[u8],
|
secret: &[u8],
|
||||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||||
session_id: &[u8],
|
session_id: &[u8],
|
||||||
fake_cert_len: usize,
|
fake_cert_len: usize,
|
||||||
|
rng: &SecureRandom,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
let x25519_key = gen_fake_x25519_key();
|
let x25519_key = gen_fake_x25519_key(rng);
|
||||||
|
|
||||||
// TLS extensions
|
// Build ServerHello
|
||||||
let mut extensions = Vec::new();
|
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||||
extensions.extend_from_slice(&[0x00, 0x2e]); // Extension length placeholder
|
.with_x25519_key(&x25519_key)
|
||||||
extensions.extend_from_slice(&[0x00, 0x33, 0x00, 0x24]); // Key share extension
|
.with_tls13_version()
|
||||||
extensions.extend_from_slice(&[0x00, 0x1d, 0x00, 0x20]); // X25519 curve
|
.build_record();
|
||||||
extensions.extend_from_slice(&x25519_key);
|
|
||||||
extensions.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]); // Supported versions
|
|
||||||
|
|
||||||
// ServerHello body
|
// Build Change Cipher Spec record
|
||||||
let mut srv_hello = Vec::new();
|
let change_cipher_spec = [
|
||||||
srv_hello.extend_from_slice(&TLS_VERSION);
|
|
||||||
srv_hello.extend_from_slice(&[0u8; TLS_DIGEST_LEN]); // Placeholder for digest
|
|
||||||
srv_hello.push(session_id.len() as u8);
|
|
||||||
srv_hello.extend_from_slice(session_id);
|
|
||||||
srv_hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
|
|
||||||
srv_hello.push(0x00); // No compression
|
|
||||||
srv_hello.extend_from_slice(&extensions);
|
|
||||||
|
|
||||||
// Build complete packet
|
|
||||||
let mut hello_pkt = Vec::new();
|
|
||||||
|
|
||||||
// ServerHello record
|
|
||||||
hello_pkt.push(TLS_RECORD_HANDSHAKE);
|
|
||||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
|
||||||
hello_pkt.extend_from_slice(&((srv_hello.len() + 4) as u16).to_be_bytes());
|
|
||||||
hello_pkt.push(0x02); // ServerHello message type
|
|
||||||
let len_bytes = (srv_hello.len() as u32).to_be_bytes();
|
|
||||||
hello_pkt.extend_from_slice(&len_bytes[1..4]); // 3-byte length
|
|
||||||
hello_pkt.extend_from_slice(&srv_hello);
|
|
||||||
|
|
||||||
// Change Cipher Spec record
|
|
||||||
hello_pkt.extend_from_slice(&[
|
|
||||||
TLS_RECORD_CHANGE_CIPHER,
|
TLS_RECORD_CHANGE_CIPHER,
|
||||||
TLS_VERSION[0], TLS_VERSION[1],
|
TLS_VERSION[0], TLS_VERSION[1],
|
||||||
0x00, 0x01, 0x01
|
0x00, 0x01, // length = 1
|
||||||
]);
|
0x01, // CCS byte
|
||||||
|
];
|
||||||
|
|
||||||
// Application Data record (fake certificate)
|
// Build fake certificate (Application Data record)
|
||||||
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
|
let fake_cert = rng.bytes(fake_cert_len);
|
||||||
hello_pkt.push(TLS_RECORD_APPLICATION);
|
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||||
hello_pkt.extend_from_slice(&TLS_VERSION);
|
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||||
hello_pkt.extend_from_slice(&(fake_cert.len() as u16).to_be_bytes());
|
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||||
hello_pkt.extend_from_slice(&fake_cert);
|
app_data_record.extend_from_slice(&(fake_cert_len as u16).to_be_bytes());
|
||||||
|
app_data_record.extend_from_slice(&fake_cert);
|
||||||
|
|
||||||
|
// Combine all records
|
||||||
|
let mut response = Vec::with_capacity(
|
||||||
|
server_hello.len() + change_cipher_spec.len() + app_data_record.len()
|
||||||
|
);
|
||||||
|
response.extend_from_slice(&server_hello);
|
||||||
|
response.extend_from_slice(&change_cipher_spec);
|
||||||
|
response.extend_from_slice(&app_data_record);
|
||||||
|
|
||||||
// Compute HMAC for the response
|
// Compute HMAC for the response
|
||||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + hello_pkt.len());
|
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
|
||||||
hmac_input.extend_from_slice(client_digest);
|
hmac_input.extend_from_slice(client_digest);
|
||||||
hmac_input.extend_from_slice(&hello_pkt);
|
hmac_input.extend_from_slice(&response);
|
||||||
let response_digest = sha256_hmac(secret, &hmac_input);
|
let response_digest = sha256_hmac(secret, &hmac_input);
|
||||||
|
|
||||||
// Insert computed digest
|
// Insert computed digest into ServerHello
|
||||||
// Position: after record header (5) + message type/length (4) + version (2) = 11
|
// Position: record header (5) + message type (1) + length (3) + version (2) = 11
|
||||||
hello_pkt[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||||
.copy_from_slice(&response_digest);
|
.copy_from_slice(&response_digest);
|
||||||
|
|
||||||
hello_pkt
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if bytes look like a TLS ClientHello
|
/// Check if bytes look like a TLS ClientHello
|
||||||
@@ -186,7 +387,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLS record header: 0x16 0x03 0x01
|
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
|
||||||
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
||||||
&& first_bytes[1] == 0x03
|
&& first_bytes[1] == 0x03
|
||||||
&& first_bytes[2] == 0x01
|
&& first_bytes[2] == 0x01
|
||||||
@@ -206,6 +407,61 @@ pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
|||||||
Some((record_type, length))
|
Some((record_type, length))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate a ServerHello response structure
|
||||||
|
///
|
||||||
|
/// This is useful for testing that our ServerHello is well-formed.
|
||||||
|
#[cfg(test)]
|
||||||
|
fn validate_server_hello_structure(data: &[u8]) -> Result<()> {
|
||||||
|
if data.len() < 5 {
|
||||||
|
return Err(ProxyError::InvalidTlsRecord {
|
||||||
|
record_type: 0,
|
||||||
|
version: [0, 0],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check record header
|
||||||
|
if data[0] != TLS_RECORD_HANDSHAKE {
|
||||||
|
return Err(ProxyError::InvalidTlsRecord {
|
||||||
|
record_type: data[0],
|
||||||
|
version: [data[1], data[2]],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check version
|
||||||
|
if data[1..3] != TLS_VERSION {
|
||||||
|
return Err(ProxyError::InvalidTlsRecord {
|
||||||
|
record_type: data[0],
|
||||||
|
version: [data[1], data[2]],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check record length
|
||||||
|
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
|
||||||
|
if data.len() < 5 + record_len {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("ServerHello record truncated: expected {}, got {}",
|
||||||
|
5 + record_len, data.len())
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check message type
|
||||||
|
if data[5] != 0x02 {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse message length
|
||||||
|
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
|
||||||
|
if msg_len + 4 != record_len {
|
||||||
|
return Err(ProxyError::InvalidHandshake(
|
||||||
|
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -234,11 +490,155 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_gen_fake_x25519_key() {
|
fn test_gen_fake_x25519_key() {
|
||||||
let key1 = gen_fake_x25519_key();
|
let rng = SecureRandom::new();
|
||||||
let key2 = gen_fake_x25519_key();
|
let key1 = gen_fake_x25519_key(&rng);
|
||||||
|
let key2 = gen_fake_x25519_key(&rng);
|
||||||
|
|
||||||
assert_eq!(key1.len(), 32);
|
assert_eq!(key1.len(), 32);
|
||||||
assert_eq!(key2.len(), 32);
|
assert_eq!(key2.len(), 32);
|
||||||
assert_ne!(key1, key2); // Should be random
|
assert_ne!(key1, key2); // Should be random
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tls_extension_builder() {
|
||||||
|
let key = [0x42u8; 32];
|
||||||
|
|
||||||
|
let mut builder = TlsExtensionBuilder::new();
|
||||||
|
builder.add_key_share(&key);
|
||||||
|
builder.add_supported_versions(0x0304);
|
||||||
|
|
||||||
|
let result = builder.build();
|
||||||
|
|
||||||
|
// Check length prefix
|
||||||
|
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
|
||||||
|
assert_eq!(len, result.len() - 2);
|
||||||
|
|
||||||
|
// Check key_share extension is present
|
||||||
|
assert!(result.len() > 40); // At least key share
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_server_hello_builder() {
|
||||||
|
let session_id = vec![0x01, 0x02, 0x03, 0x04];
|
||||||
|
let key = [0x55u8; 32];
|
||||||
|
|
||||||
|
let builder = ServerHelloBuilder::new(session_id.clone())
|
||||||
|
.with_x25519_key(&key)
|
||||||
|
.with_tls13_version();
|
||||||
|
|
||||||
|
let record = builder.build_record();
|
||||||
|
|
||||||
|
// Validate structure
|
||||||
|
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
|
||||||
|
|
||||||
|
// Check record type
|
||||||
|
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
|
||||||
|
|
||||||
|
// Check version
|
||||||
|
assert_eq!(&record[1..3], &TLS_VERSION);
|
||||||
|
|
||||||
|
// Check message type (ServerHello = 0x02)
|
||||||
|
assert_eq!(record[5], 0x02);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_server_hello_structure() {
|
||||||
|
let secret = b"test secret";
|
||||||
|
let client_digest = [0x42u8; 32];
|
||||||
|
let session_id = vec![0xAA; 32];
|
||||||
|
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
|
||||||
|
|
||||||
|
// Should have at least 3 records
|
||||||
|
assert!(response.len() > 100);
|
||||||
|
|
||||||
|
// First record should be ServerHello
|
||||||
|
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
|
||||||
|
|
||||||
|
// Validate ServerHello structure
|
||||||
|
validate_server_hello_structure(&response).expect("Invalid ServerHello");
|
||||||
|
|
||||||
|
// Find Change Cipher Spec
|
||||||
|
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_start = server_hello_len;
|
||||||
|
|
||||||
|
assert!(response.len() > ccs_start + 6);
|
||||||
|
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
|
||||||
|
|
||||||
|
// Find Application Data
|
||||||
|
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
|
||||||
|
let app_start = ccs_start + ccs_len;
|
||||||
|
|
||||||
|
assert!(response.len() > app_start + 5);
|
||||||
|
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_server_hello_digest() {
|
||||||
|
let secret = b"test secret key here";
|
||||||
|
let client_digest = [0x42u8; 32];
|
||||||
|
let session_id = vec![0xAA; 32];
|
||||||
|
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||||
|
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
|
||||||
|
|
||||||
|
// Digest position should have non-zero data
|
||||||
|
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||||
|
assert!(!digest1.iter().all(|&b| b == 0));
|
||||||
|
|
||||||
|
// Different calls should have different digests (due to random cert)
|
||||||
|
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
||||||
|
assert_ne!(digest1, digest2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_server_hello_extensions_length() {
|
||||||
|
let session_id = vec![0x01; 32];
|
||||||
|
let key = [0x55u8; 32];
|
||||||
|
|
||||||
|
let builder = ServerHelloBuilder::new(session_id)
|
||||||
|
.with_x25519_key(&key)
|
||||||
|
.with_tls13_version();
|
||||||
|
|
||||||
|
let record = builder.build_record();
|
||||||
|
|
||||||
|
// Parse to find extensions
|
||||||
|
let msg_start = 5; // After record header
|
||||||
|
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
|
||||||
|
|
||||||
|
// Skip to session ID
|
||||||
|
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
|
||||||
|
let session_id_len = record[session_id_pos] as usize;
|
||||||
|
|
||||||
|
// Skip to extensions
|
||||||
|
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
|
||||||
|
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
|
||||||
|
|
||||||
|
// Verify extensions length matches actual data
|
||||||
|
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
|
||||||
|
assert_eq!(ext_len, extensions_data.len(),
|
||||||
|
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_tls_handshake_format() {
|
||||||
|
// Build a minimal ClientHello-like structure
|
||||||
|
let mut handshake = vec![0u8; 100];
|
||||||
|
|
||||||
|
// Put a valid-looking digest at position 11
|
||||||
|
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
||||||
|
.copy_from_slice(&[0x42; 32]);
|
||||||
|
|
||||||
|
// Session ID length
|
||||||
|
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
|
||||||
|
|
||||||
|
// This won't validate (wrong HMAC) but shouldn't panic
|
||||||
|
let secrets = vec![("test".to_string(), b"secret".to_vec())];
|
||||||
|
let result = validate_tls_handshake(&handshake, &secrets, true);
|
||||||
|
|
||||||
|
// Should return None (no match) but not panic
|
||||||
|
assert!(result.is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -13,103 +13,104 @@ use crate::error::{ProxyError, Result, HandshakeResult};
|
|||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
use crate::stats::{Stats, ReplayChecker};
|
use crate::stats::{Stats, ReplayChecker};
|
||||||
use crate::transport::{ConnectionPool, configure_client_socket};
|
use crate::transport::{configure_client_socket, UpstreamManager};
|
||||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::{AesCtr, SecureRandom};
|
||||||
|
|
||||||
use super::handshake::{
|
use crate::proxy::handshake::{
|
||||||
handle_tls_handshake, handle_mtproto_handshake,
|
handle_tls_handshake, handle_mtproto_handshake,
|
||||||
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
|
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
|
||||||
};
|
};
|
||||||
use super::relay::relay_bidirectional;
|
use crate::proxy::relay::relay_bidirectional;
|
||||||
use super::masking::handle_bad_client;
|
use crate::proxy::masking::handle_bad_client;
|
||||||
|
|
||||||
/// Client connection handler
|
pub struct ClientHandler;
|
||||||
pub struct ClientHandler {
|
|
||||||
|
pub struct RunningClientHandler {
|
||||||
|
stream: TcpStream,
|
||||||
|
peer: SocketAddr,
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
pool: Arc<ConnectionPool>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientHandler {
|
impl ClientHandler {
|
||||||
/// Create new client handler
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
stream: TcpStream,
|
||||||
|
peer: SocketAddr,
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
pool: Arc<ConnectionPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
) -> Self {
|
rng: Arc<SecureRandom>,
|
||||||
Self {
|
) -> RunningClientHandler {
|
||||||
config,
|
RunningClientHandler {
|
||||||
stats,
|
stream, peer, config, stats, replay_checker,
|
||||||
replay_checker,
|
upstream_manager, buffer_pool, rng,
|
||||||
pool,
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a client connection
|
impl RunningClientHandler {
|
||||||
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
|
pub async fn run(mut self) -> Result<()> {
|
||||||
self.stats.increment_connects_all();
|
self.stats.increment_connects_all();
|
||||||
|
|
||||||
|
let peer = self.peer;
|
||||||
debug!(peer = %peer, "New connection");
|
debug!(peer = %peer, "New connection");
|
||||||
|
|
||||||
// Configure socket
|
|
||||||
if let Err(e) = configure_client_socket(
|
if let Err(e) = configure_client_socket(
|
||||||
&stream,
|
&self.stream,
|
||||||
self.config.client_keepalive,
|
self.config.timeouts.client_keepalive,
|
||||||
self.config.client_ack_timeout,
|
self.config.timeouts.client_ack,
|
||||||
) {
|
) {
|
||||||
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform handshake with timeout
|
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
||||||
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
|
let stats = self.stats.clone();
|
||||||
|
|
||||||
let result = timeout(
|
let result = timeout(handshake_timeout, self.do_handshake()).await;
|
||||||
handshake_timeout,
|
|
||||||
self.do_handshake(stream, peer)
|
|
||||||
).await;
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => {
|
Ok(Ok(())) => {
|
||||||
debug!(peer = %peer, "Connection handled successfully");
|
debug!(peer = %peer, "Connection handled successfully");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||||
|
Err(e)
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
self.stats.increment_handshake_timeouts();
|
stats.increment_handshake_timeouts();
|
||||||
debug!(peer = %peer, "Handshake timeout");
|
debug!(peer = %peer, "Handshake timeout");
|
||||||
|
Err(ProxyError::TgHandshakeTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform handshake and relay
|
async fn do_handshake(mut self) -> Result<()> {
|
||||||
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
|
|
||||||
// Read first bytes to determine handshake type
|
|
||||||
let mut first_bytes = [0u8; 5];
|
let mut first_bytes = [0u8; 5];
|
||||||
stream.read_exact(&mut first_bytes).await?;
|
self.stream.read_exact(&mut first_bytes).await?;
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
|
let peer = self.peer;
|
||||||
|
|
||||||
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
|
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||||
|
|
||||||
if is_tls {
|
if is_tls {
|
||||||
self.handle_tls_client(stream, peer, first_bytes).await
|
self.handle_tls_client(first_bytes).await
|
||||||
} else {
|
} else {
|
||||||
self.handle_direct_client(stream, peer, first_bytes).await
|
self.handle_direct_client(first_bytes).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle TLS-wrapped client
|
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
||||||
async fn handle_tls_client(
|
let peer = self.peer;
|
||||||
&self,
|
|
||||||
mut stream: TcpStream,
|
|
||||||
peer: SocketAddr,
|
|
||||||
first_bytes: [u8; 5],
|
|
||||||
) -> Result<()> {
|
|
||||||
// Read TLS handshake length
|
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
||||||
@@ -117,113 +118,111 @@ impl ClientHandler {
|
|||||||
if tls_len < 512 {
|
if tls_len < 512 {
|
||||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
let (reader, writer) = self.stream.into_split();
|
||||||
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read full TLS handshake
|
|
||||||
let mut handshake = vec![0u8; 5 + tls_len];
|
let mut handshake = vec![0u8; 5 + tls_len];
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
stream.read_exact(&mut handshake[5..]).await?;
|
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
// Split stream for reading/writing
|
let config = self.config.clone();
|
||||||
let (read_half, write_half) = stream.into_split();
|
let replay_checker = self.replay_checker.clone();
|
||||||
|
let stats = self.stats.clone();
|
||||||
|
let buffer_pool = self.buffer_pool.clone();
|
||||||
|
|
||||||
|
let (read_half, write_half) = self.stream.into_split();
|
||||||
|
|
||||||
// Handle TLS handshake
|
|
||||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||||
&handshake,
|
&handshake, read_half, write_half, peer,
|
||||||
read_half,
|
&config, &replay_checker, &self.rng,
|
||||||
write_half,
|
|
||||||
peer,
|
|
||||||
&self.config,
|
|
||||||
&self.replay_checker,
|
|
||||||
).await {
|
).await {
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
self.stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Read MTProto handshake through TLS
|
|
||||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||||
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
||||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
||||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||||
|
|
||||||
// Handle MTProto handshake
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
&mtproto_handshake,
|
&mtproto_handshake, tls_reader, tls_writer, peer,
|
||||||
tls_reader,
|
&config, &replay_checker, true,
|
||||||
tls_writer,
|
|
||||||
peer,
|
|
||||||
&self.config,
|
|
||||||
&self.replay_checker,
|
|
||||||
true,
|
|
||||||
).await {
|
).await {
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient => {
|
HandshakeResult::BadClient { reader: _, writer: _ } => {
|
||||||
self.stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Handle authenticated client
|
Self::handle_authenticated_static(
|
||||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
crypto_reader, crypto_writer, success,
|
||||||
|
self.upstream_manager, self.stats, self.config,
|
||||||
|
buffer_pool, self.rng,
|
||||||
|
).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle direct (non-TLS) client
|
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
||||||
async fn handle_direct_client(
|
let peer = self.peer;
|
||||||
&self,
|
|
||||||
mut stream: TcpStream,
|
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
||||||
peer: SocketAddr,
|
|
||||||
first_bytes: [u8; 5],
|
|
||||||
) -> Result<()> {
|
|
||||||
// Check if non-TLS modes are enabled
|
|
||||||
if !self.config.modes.classic && !self.config.modes.secure {
|
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
let (reader, writer) = self.stream.into_split();
|
||||||
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read rest of handshake
|
|
||||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
stream.read_exact(&mut handshake[5..]).await?;
|
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
// Split stream
|
let config = self.config.clone();
|
||||||
let (read_half, write_half) = stream.into_split();
|
let replay_checker = self.replay_checker.clone();
|
||||||
|
let stats = self.stats.clone();
|
||||||
|
let buffer_pool = self.buffer_pool.clone();
|
||||||
|
|
||||||
|
let (read_half, write_half) = self.stream.into_split();
|
||||||
|
|
||||||
// Handle MTProto handshake
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
&handshake,
|
&handshake, read_half, write_half, peer,
|
||||||
read_half,
|
&config, &replay_checker, false,
|
||||||
write_half,
|
|
||||||
peer,
|
|
||||||
&self.config,
|
|
||||||
&self.replay_checker,
|
|
||||||
false,
|
|
||||||
).await {
|
).await {
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
self.stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
Self::handle_authenticated_static(
|
||||||
|
crypto_reader, crypto_writer, success,
|
||||||
|
self.upstream_manager, self.stats, self.config,
|
||||||
|
buffer_pool, self.rng,
|
||||||
|
).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle authenticated client - connect to Telegram and relay
|
async fn handle_authenticated_static<R, W>(
|
||||||
async fn handle_authenticated_inner<R, W>(
|
|
||||||
&self,
|
|
||||||
client_reader: CryptoReader<R>,
|
client_reader: CryptoReader<R>,
|
||||||
client_writer: CryptoWriter<W>,
|
client_writer: CryptoWriter<W>,
|
||||||
success: HandshakeSuccess,
|
success: HandshakeSuccess,
|
||||||
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
config: Arc<ProxyConfig>,
|
||||||
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
@@ -231,14 +230,12 @@ impl ClientHandler {
|
|||||||
{
|
{
|
||||||
let user = &success.user;
|
let user = &success.user;
|
||||||
|
|
||||||
// Check user limits
|
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
|
||||||
if let Err(e) = self.check_user_limits(user) {
|
|
||||||
warn!(user = %user, error = %e, "User limit exceeded");
|
warn!(user = %user, error = %e, "User limit exceeded");
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get datacenter address
|
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
|
||||||
let dc_addr = self.get_dc_addr(success.dc_idx)?;
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
user = %user,
|
user = %user,
|
||||||
@@ -246,69 +243,54 @@ impl ClientHandler {
|
|||||||
dc = success.dc_idx,
|
dc = success.dc_idx,
|
||||||
dc_addr = %dc_addr,
|
dc_addr = %dc_addr,
|
||||||
proto = ?success.proto_tag,
|
proto = ?success.proto_tag,
|
||||||
fast_mode = self.config.fast_mode,
|
|
||||||
"Connecting to Telegram"
|
"Connecting to Telegram"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Connect to Telegram
|
// Pass dc_idx for latency-based upstream selection
|
||||||
let tg_stream = self.pool.get(dc_addr).await?;
|
let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?;
|
||||||
|
|
||||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
|
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||||
|
|
||||||
// Perform Telegram handshake and get crypto streams
|
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
|
||||||
let (tg_reader, tg_writer) = self.do_tg_handshake(
|
tg_stream, &success, &config, rng.as_ref(),
|
||||||
tg_stream,
|
|
||||||
&success,
|
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
// Update stats
|
stats.increment_user_connects(user);
|
||||||
self.stats.increment_user_connects(user);
|
stats.increment_user_curr_connects(user);
|
||||||
self.stats.increment_user_curr_connects(user);
|
|
||||||
|
|
||||||
// Relay traffic - передаём Arc::clone(&self.stats)
|
|
||||||
let relay_result = relay_bidirectional(
|
let relay_result = relay_bidirectional(
|
||||||
client_reader,
|
client_reader, client_writer,
|
||||||
client_writer,
|
tg_reader, tg_writer,
|
||||||
tg_reader,
|
user, Arc::clone(&stats), buffer_pool,
|
||||||
tg_writer,
|
|
||||||
user,
|
|
||||||
Arc::clone(&self.stats),
|
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
// Update stats
|
stats.decrement_user_curr_connects(user);
|
||||||
self.stats.decrement_user_curr_connects(user);
|
|
||||||
|
|
||||||
match &relay_result {
|
match &relay_result {
|
||||||
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
|
Ok(()) => debug!(user = %user, "Relay completed"),
|
||||||
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"),
|
Err(e) => debug!(user = %user, error = %e, "Relay ended with error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
relay_result
|
relay_result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check user limits (expiration, connection count, data quota)
|
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
|
||||||
fn check_user_limits(&self, user: &str) -> Result<()> {
|
if let Some(expiration) = config.access.user_expirations.get(user) {
|
||||||
// Check expiration
|
|
||||||
if let Some(expiration) = self.config.user_expirations.get(user) {
|
|
||||||
if chrono::Utc::now() > *expiration {
|
if chrono::Utc::now() > *expiration {
|
||||||
return Err(ProxyError::UserExpired { user: user.to_string() });
|
return Err(ProxyError::UserExpired { user: user.to_string() });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check connection limit
|
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
|
||||||
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
|
if stats.get_user_curr_connects(user) >= *limit as u64 {
|
||||||
let current = self.stats.get_user_curr_connects(user);
|
|
||||||
if current >= *limit as u64 {
|
|
||||||
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check data quota
|
if let Some(quota) = config.access.user_data_quota.get(user) {
|
||||||
if let Some(quota) = self.config.user_data_quota.get(user) {
|
if stats.get_user_total_octets(user) >= *quota {
|
||||||
let used = self.stats.get_user_total_octets(user);
|
|
||||||
if used >= *quota {
|
|
||||||
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -316,63 +298,105 @@ impl ClientHandler {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get datacenter address by index
|
/// Resolve DC index to a target address.
|
||||||
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
|
///
|
||||||
let idx = (dc_idx.abs() - 1) as usize;
|
/// Matches the C implementation's behavior exactly:
|
||||||
|
///
|
||||||
let datacenters = if self.config.prefer_ipv6 {
|
/// 1. Look up DC in known clusters (standard DCs ±1..±5)
|
||||||
|
/// 2. If not found and `force=1` → fall back to `default_cluster`
|
||||||
|
///
|
||||||
|
/// In the C code:
|
||||||
|
/// - `proxy-multi.conf` is downloaded from Telegram, contains only DC ±1..±5
|
||||||
|
/// - `default 2;` directive sets the default cluster
|
||||||
|
/// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns default_cluster
|
||||||
|
/// for any unknown DC (like CDN DC 203)
|
||||||
|
///
|
||||||
|
/// So DC 203, DC 101, DC -300, etc. all route to the default DC (2).
|
||||||
|
/// There is NO modular arithmetic in the C implementation.
|
||||||
|
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||||
|
let datacenters = if config.general.prefer_ipv6 {
|
||||||
&*TG_DATACENTERS_V6
|
&*TG_DATACENTERS_V6
|
||||||
} else {
|
} else {
|
||||||
&*TG_DATACENTERS_V4
|
&*TG_DATACENTERS_V4
|
||||||
};
|
};
|
||||||
|
|
||||||
datacenters.get(idx)
|
let num_dcs = datacenters.len(); // 5
|
||||||
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
|
|
||||||
.ok_or_else(|| ProxyError::InvalidHandshake(
|
// === Step 1: Check dc_overrides (like C's `proxy_for <dc> <ip>:<port>`) ===
|
||||||
format!("Invalid DC index: {}", dc_idx)
|
let dc_key = dc_idx.to_string();
|
||||||
))
|
if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
|
||||||
|
match addr_str.parse::<SocketAddr>() {
|
||||||
|
Ok(addr) => {
|
||||||
|
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
|
||||||
|
return Ok(addr);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!(dc_idx = dc_idx, addr_str = %addr_str,
|
||||||
|
"Invalid DC override address in config, ignoring");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform handshake with Telegram server
|
// === Step 2: Standard DCs ±1..±5 — direct lookup ===
|
||||||
/// Returns crypto reader and writer for TG connection
|
let abs_dc = dc_idx.unsigned_abs() as usize;
|
||||||
async fn do_tg_handshake(
|
if abs_dc >= 1 && abs_dc <= num_dcs {
|
||||||
&self,
|
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
|
||||||
mut stream: TcpStream,
|
}
|
||||||
success: &HandshakeSuccess,
|
|
||||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
// === Step 3: Unknown DC — fall back to default_cluster ===
|
||||||
// Generate nonce with keys for TG
|
// Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)`
|
||||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
// which returns `MC->default_cluster` when the DC is not found.
|
||||||
success.proto_tag,
|
// Telegram's proxy-multi.conf uses `default 2;`
|
||||||
&success.dec_key, // Client's dec key
|
let default_dc = config.default_dc.unwrap_or(2) as usize;
|
||||||
success.dec_iv,
|
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
||||||
self.config.fast_mode,
|
default_dc - 1
|
||||||
|
} else {
|
||||||
|
1 // DC 2 (index 1) — matches Telegram's `default 2;`
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
original_dc = dc_idx,
|
||||||
|
fallback_dc = (fallback_idx + 1) as u16,
|
||||||
|
fallback_addr = %datacenters[fallback_idx],
|
||||||
|
"Special DC ---> default_cluster"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_tg_handshake_static(
|
||||||
|
mut stream: TcpStream,
|
||||||
|
success: &HandshakeSuccess,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
rng: &SecureRandom,
|
||||||
|
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
||||||
|
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||||
|
success.proto_tag,
|
||||||
|
&success.dec_key,
|
||||||
|
success.dec_iv,
|
||||||
|
rng,
|
||||||
|
config.general.fast_mode,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Encrypt nonce
|
|
||||||
let encrypted_nonce = encrypt_tg_nonce(&nonce);
|
let encrypted_nonce = encrypt_tg_nonce(&nonce);
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
peer = %success.peer,
|
peer = %success.peer,
|
||||||
nonce_head = %hex::encode(&nonce[..16]),
|
nonce_head = %hex::encode(&nonce[..16]),
|
||||||
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
|
|
||||||
"Sending nonce to Telegram"
|
"Sending nonce to Telegram"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Send to Telegram
|
|
||||||
stream.write_all(&encrypted_nonce).await?;
|
stream.write_all(&encrypted_nonce).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
|
|
||||||
debug!(peer = %success.peer, "Nonce sent to Telegram");
|
|
||||||
|
|
||||||
// Split stream and wrap with crypto
|
|
||||||
let (read_half, write_half) = stream.into_split();
|
let (read_half, write_half) = stream.into_split();
|
||||||
|
|
||||||
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
|
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
|
||||||
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
|
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
|
||||||
|
|
||||||
let tg_reader = CryptoReader::new(read_half, decryptor);
|
Ok((
|
||||||
let tg_writer = CryptoWriter::new(write_half, encryptor);
|
CryptoReader::new(read_half, decryptor),
|
||||||
|
CryptoWriter::new(write_half, encryptor),
|
||||||
Ok((tg_reader, tg_writer))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
//! MTProto Handshake Magics
|
//! MTProto Handshake
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tracing::{debug, warn, trace, info};
|
use tracing::{debug, warn, trace, info};
|
||||||
|
use zeroize::Zeroize;
|
||||||
|
|
||||||
use crate::crypto::{sha256, AesCtr};
|
use crate::crypto::{sha256, AesCtr, SecureRandom};
|
||||||
use crate::crypto::random::SECURE_RANDOM;
|
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
||||||
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
|
|||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
|
||||||
/// Result of successful handshake
|
/// Result of successful handshake
|
||||||
|
///
|
||||||
|
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
||||||
|
/// zeroized on drop.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct HandshakeSuccess {
|
pub struct HandshakeSuccess {
|
||||||
/// Authenticated user name
|
/// Authenticated user name
|
||||||
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
|
|||||||
pub is_tls: bool,
|
pub is_tls: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for HandshakeSuccess {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dec_key.zeroize();
|
||||||
|
self.dec_iv.zeroize();
|
||||||
|
self.enc_key.zeroize();
|
||||||
|
self.enc_iv.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle fake TLS handshake
|
/// Handle fake TLS handshake
|
||||||
pub async fn handle_tls_handshake<R, W>(
|
pub async fn handle_tls_handshake<R, W>(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
@@ -42,76 +54,74 @@ pub async fn handle_tls_handshake<R, W>(
|
|||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
replay_checker: &ReplayChecker,
|
replay_checker: &ReplayChecker,
|
||||||
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)>
|
rng: &SecureRandom,
|
||||||
|
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin,
|
R: AsyncRead + Unpin,
|
||||||
W: AsyncWrite + Unpin,
|
W: AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||||
|
|
||||||
// Check minimum length
|
|
||||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||||
debug!(peer = %peer, "TLS handshake too short");
|
debug!(peer = %peer, "TLS handshake too short");
|
||||||
return HandshakeResult::BadClient;
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract digest for replay check
|
|
||||||
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
||||||
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||||
|
|
||||||
// Check for replay
|
|
||||||
if replay_checker.check_tls_digest(digest_half) {
|
if replay_checker.check_tls_digest(digest_half) {
|
||||||
warn!(peer = %peer, "TLS replay attack detected");
|
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||||
return HandshakeResult::BadClient;
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build secrets list
|
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
|
||||||
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
|
|
||||||
.filter_map(|(name, hex)| {
|
.filter_map(|(name, hex)| {
|
||||||
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
|
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
|
|
||||||
|
|
||||||
// Validate handshake
|
|
||||||
let validation = match tls::validate_tls_handshake(
|
let validation = match tls::validate_tls_handshake(
|
||||||
handshake,
|
handshake,
|
||||||
&secrets,
|
&secrets,
|
||||||
config.ignore_time_skew,
|
config.access.ignore_time_skew,
|
||||||
) {
|
) {
|
||||||
Some(v) => v,
|
Some(v) => v,
|
||||||
None => {
|
None => {
|
||||||
debug!(peer = %peer, "TLS handshake validation failed - no matching user");
|
debug!(
|
||||||
return HandshakeResult::BadClient;
|
peer = %peer,
|
||||||
|
ignore_time_skew = config.access.ignore_time_skew,
|
||||||
|
"TLS handshake validation failed - no matching user or time skew"
|
||||||
|
);
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get secret for response
|
|
||||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||||
Some((_, s)) => s,
|
Some((_, s)) => s,
|
||||||
None => return HandshakeResult::BadClient,
|
None => return HandshakeResult::BadClient { reader, writer },
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build and send response
|
|
||||||
let response = tls::build_server_hello(
|
let response = tls::build_server_hello(
|
||||||
secret,
|
secret,
|
||||||
&validation.digest,
|
&validation.digest,
|
||||||
&validation.session_id,
|
&validation.session_id,
|
||||||
config.fake_cert_len,
|
config.censorship.fake_cert_len,
|
||||||
|
rng,
|
||||||
);
|
);
|
||||||
|
|
||||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||||
|
|
||||||
if let Err(e) = writer.write_all(&response).await {
|
if let Err(e) = writer.write_all(&response).await {
|
||||||
|
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
|
||||||
return HandshakeResult::Error(ProxyError::Io(e));
|
return HandshakeResult::Error(ProxyError::Io(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = writer.flush().await {
|
if let Err(e) = writer.flush().await {
|
||||||
|
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
|
||||||
return HandshakeResult::Error(ProxyError::Io(e));
|
return HandshakeResult::Error(ProxyError::Io(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record for replay protection
|
|
||||||
replay_checker.add_tls_digest(digest_half);
|
replay_checker.add_tls_digest(digest_half);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
@@ -136,39 +146,28 @@ pub async fn handle_mtproto_handshake<R, W>(
|
|||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
replay_checker: &ReplayChecker,
|
replay_checker: &ReplayChecker,
|
||||||
is_tls: bool,
|
is_tls: bool,
|
||||||
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)>
|
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send,
|
R: AsyncRead + Unpin + Send,
|
||||||
W: AsyncWrite + Unpin + Send,
|
W: AsyncWrite + Unpin + Send,
|
||||||
{
|
{
|
||||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||||
|
|
||||||
// Extract prekey and IV
|
|
||||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||||
|
|
||||||
debug!(
|
|
||||||
peer = %peer,
|
|
||||||
dec_prekey_iv = %hex::encode(dec_prekey_iv),
|
|
||||||
"Extracted prekey+IV from handshake"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check for replay
|
|
||||||
if replay_checker.check_handshake(dec_prekey_iv) {
|
if replay_checker.check_handshake(dec_prekey_iv) {
|
||||||
warn!(peer = %peer, "MTProto replay attack detected");
|
warn!(peer = %peer, "MTProto replay attack detected");
|
||||||
return HandshakeResult::BadClient;
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reversed for encryption direction
|
|
||||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
// Try each user's secret
|
for (user, secret_hex) in &config.access.users {
|
||||||
for (user, secret_hex) in &config.users {
|
|
||||||
let secret = match hex::decode(secret_hex) {
|
let secret = match hex::decode(secret_hex) {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(_) => continue,
|
Err(_) => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Derive decryption key
|
|
||||||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
@@ -179,38 +178,23 @@ where
|
|||||||
|
|
||||||
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Decrypt handshake to check protocol tag
|
|
||||||
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||||
let decrypted = decryptor.decrypt(handshake);
|
let decrypted = decryptor.decrypt(handshake);
|
||||||
|
|
||||||
trace!(
|
|
||||||
peer = %peer,
|
|
||||||
user = %user,
|
|
||||||
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
|
|
||||||
"Decrypted handshake tail"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check protocol tag
|
|
||||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
||||||
.try_into()
|
.try_into()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||||
Some(tag) => tag,
|
Some(tag) => tag,
|
||||||
None => {
|
None => continue,
|
||||||
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
|
|
||||||
|
|
||||||
// Check if mode is enabled
|
|
||||||
let mode_ok = match proto_tag {
|
let mode_ok = match proto_tag {
|
||||||
ProtoTag::Secure => {
|
ProtoTag::Secure => {
|
||||||
if is_tls { config.modes.tls } else { config.modes.secure }
|
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
|
||||||
}
|
}
|
||||||
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic,
|
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
|
||||||
};
|
};
|
||||||
|
|
||||||
if !mode_ok {
|
if !mode_ok {
|
||||||
@@ -218,12 +202,10 @@ where
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract DC index
|
|
||||||
let dc_idx = i16::from_le_bytes(
|
let dc_idx = i16::from_le_bytes(
|
||||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Derive encryption key
|
|
||||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||||
|
|
||||||
@@ -234,10 +216,8 @@ where
|
|||||||
|
|
||||||
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
||||||
|
|
||||||
// Record for replay protection
|
|
||||||
replay_checker.add_handshake(dec_prekey_iv);
|
replay_checker.add_handshake(dec_prekey_iv);
|
||||||
|
|
||||||
// Create new cipher instances
|
|
||||||
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||||
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
|
||||||
@@ -270,56 +250,37 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||||
HandshakeResult::BadClient
|
HandshakeResult::BadClient { reader, writer }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate nonce for Telegram connection
|
/// Generate nonce for Telegram connection
|
||||||
///
|
|
||||||
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
|
|
||||||
/// This means: client's enc_key becomes TG's dec_key and vice versa.
|
|
||||||
pub fn generate_tg_nonce(
|
pub fn generate_tg_nonce(
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
client_dec_key: &[u8; 32],
|
client_dec_key: &[u8; 32],
|
||||||
client_dec_iv: u128,
|
client_dec_iv: u128,
|
||||||
|
rng: &SecureRandom,
|
||||||
fast_mode: bool,
|
fast_mode: bool,
|
||||||
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
||||||
loop {
|
loop {
|
||||||
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
|
let bytes = rng.bytes(HANDSHAKE_LEN);
|
||||||
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
||||||
|
|
||||||
// Check reserved patterns
|
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
||||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
||||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
|
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
||||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
|
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set protocol tag
|
|
||||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||||
|
|
||||||
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
|
|
||||||
// In fast mode, we make TG use the same keys as client but swapped:
|
|
||||||
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
|
|
||||||
// - What we encrypt TO TG = what we decrypt FROM client
|
|
||||||
if fast_mode {
|
if fast_mode {
|
||||||
// Put client's dec_key + dec_iv into nonce[8:56]
|
|
||||||
// This will be used by TG for encryption TO us
|
|
||||||
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
|
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
|
||||||
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
|
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
|
||||||
.copy_from_slice(&client_dec_iv.to_be_bytes());
|
.copy_from_slice(&client_dec_iv.to_be_bytes());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now compute what keys WE will use for TG connection
|
|
||||||
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
|
|
||||||
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
|
|
||||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
@@ -329,44 +290,22 @@ pub fn generate_tg_nonce(
|
|||||||
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
||||||
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
||||||
|
|
||||||
debug!(
|
|
||||||
fast_mode = fast_mode,
|
|
||||||
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
|
|
||||||
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
|
|
||||||
"Generated TG nonce"
|
|
||||||
);
|
|
||||||
|
|
||||||
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
|
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt nonce for sending to Telegram
|
/// Encrypt nonce for sending to Telegram
|
||||||
///
|
|
||||||
/// Only the part from PROTO_TAG_POS onwards is encrypted.
|
|
||||||
/// The encryption key is derived from enc_key_iv in the nonce itself.
|
|
||||||
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||||
// enc_key_iv is at nonce[8:56]
|
|
||||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
|
|
||||||
// Key for encrypting is just the first 32 bytes of enc_key_iv
|
|
||||||
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
||||||
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
||||||
|
|
||||||
let mut encryptor = AesCtr::new(&key, iv);
|
let mut encryptor = AesCtr::new(&key, iv);
|
||||||
|
|
||||||
// Encrypt the entire nonce first, then take only the encrypted tail
|
|
||||||
let encrypted_full = encryptor.encrypt(nonce);
|
let encrypted_full = encryptor.encrypt(nonce);
|
||||||
|
|
||||||
// Result: unencrypted head + encrypted tail
|
|
||||||
let mut result = nonce[..PROTO_TAG_POS].to_vec();
|
let mut result = nonce[..PROTO_TAG_POS].to_vec();
|
||||||
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
||||||
|
|
||||||
trace!(
|
|
||||||
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
|
|
||||||
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
|
|
||||||
"Encrypted nonce tail"
|
|
||||||
);
|
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,13 +318,12 @@ mod tests {
|
|||||||
let client_dec_key = [0x42u8; 32];
|
let client_dec_key = [0x42u8; 32];
|
||||||
let client_dec_iv = 12345u128;
|
let client_dec_iv = 12345u128;
|
||||||
|
|
||||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
|
let rng = SecureRandom::new();
|
||||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
|
||||||
|
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
|
||||||
|
|
||||||
// Check length
|
|
||||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
||||||
|
|
||||||
// Check proto tag is set
|
|
||||||
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
||||||
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
||||||
}
|
}
|
||||||
@@ -395,17 +333,35 @@ mod tests {
|
|||||||
let client_dec_key = [0x42u8; 32];
|
let client_dec_key = [0x42u8; 32];
|
||||||
let client_dec_iv = 12345u128;
|
let client_dec_iv = 12345u128;
|
||||||
|
|
||||||
|
let rng = SecureRandom::new();
|
||||||
let (nonce, _, _, _, _) =
|
let (nonce, _, _, _, _) =
|
||||||
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
|
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
|
||||||
|
|
||||||
let encrypted = encrypt_tg_nonce(&nonce);
|
let encrypted = encrypt_tg_nonce(&nonce);
|
||||||
|
|
||||||
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
||||||
|
|
||||||
// First PROTO_TAG_POS bytes should be unchanged
|
|
||||||
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
||||||
|
|
||||||
// Rest should be different (encrypted)
|
|
||||||
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_handshake_success_zeroize_on_drop() {
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "test".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Secure,
|
||||||
|
dec_key: [0xAA; 32],
|
||||||
|
dec_iv: 0xBBBBBBBB,
|
||||||
|
enc_key: [0xCC; 32],
|
||||||
|
enc_iv: 0xDDDDDDDD,
|
||||||
|
peer: "127.0.0.1:1234".parse().unwrap(),
|
||||||
|
is_tls: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(success.dec_key, [0xAA; 32]);
|
||||||
|
assert_eq!(success.enc_key, [0xCC; 32]);
|
||||||
|
|
||||||
|
drop(success);
|
||||||
|
// Drop impl zeroizes key material without panic
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,35 +1,76 @@
|
|||||||
//! Masking - forward unrecognized traffic to mask host
|
//! Masking - forward unrecognized traffic to mask host
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use std::str;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::transport::set_linger_zero;
|
|
||||||
|
|
||||||
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
/// Maximum duration for the entire masking relay.
|
||||||
|
/// Limits resource consumption from slow-loris attacks and port scanners.
|
||||||
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||||
const MASK_BUFFER_SIZE: usize = 8192;
|
const MASK_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
/// Detect client type based on initial data
|
||||||
|
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||||
|
// Check for HTTP request
|
||||||
|
if data.len() > 4 {
|
||||||
|
if data.starts_with(b"GET ") || data.starts_with(b"POST") ||
|
||||||
|
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
|
||||||
|
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") {
|
||||||
|
return "HTTP";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
|
||||||
|
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
|
||||||
|
return "TLS-scanner";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for SSH
|
||||||
|
if data.starts_with(b"SSH-") {
|
||||||
|
return "SSH";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port scanner (very short data)
|
||||||
|
if data.len() < 10 {
|
||||||
|
return "port-scanner";
|
||||||
|
}
|
||||||
|
|
||||||
|
"unknown"
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle a bad client by forwarding to mask host
|
/// Handle a bad client by forwarding to mask host
|
||||||
pub async fn handle_bad_client(
|
pub async fn handle_bad_client<R, W>(
|
||||||
mut client: TcpStream,
|
mut reader: R,
|
||||||
|
mut writer: W,
|
||||||
initial_data: &[u8],
|
initial_data: &[u8],
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
) {
|
)
|
||||||
if !config.mask {
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
if !config.censorship.mask {
|
||||||
// Masking disabled, just consume data
|
// Masking disabled, just consume data
|
||||||
consume_client_data(client).await;
|
consume_client_data(reader).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mask_host = config.mask_host.as_deref()
|
let client_type = detect_client_type(initial_data);
|
||||||
.unwrap_or(&config.tls_domain);
|
|
||||||
let mask_port = config.mask_port;
|
let mask_host = config.censorship.mask_host.as_deref()
|
||||||
|
.unwrap_or(&config.censorship.tls_domain);
|
||||||
|
let mask_port = config.censorship.mask_port;
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
|
client_type = client_type,
|
||||||
host = %mask_host,
|
host = %mask_host,
|
||||||
port = mask_port,
|
port = mask_port,
|
||||||
|
data_len = initial_data.len(),
|
||||||
"Forwarding bad client to mask host"
|
"Forwarding bad client to mask host"
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -40,33 +81,32 @@ pub async fn handle_bad_client(
|
|||||||
TcpStream::connect(&mask_addr)
|
TcpStream::connect(&mask_addr)
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
let mut mask_stream = match connect_result {
|
let mask_stream = match connect_result {
|
||||||
Ok(Ok(s)) => s,
|
Ok(Ok(s)) => s,
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(error = %e, "Failed to connect to mask host");
|
debug!(error = %e, "Failed to connect to mask host");
|
||||||
consume_client_data(client).await;
|
consume_client_data(reader).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask host");
|
debug!("Timeout connecting to mask host");
|
||||||
consume_client_data(client).await;
|
consume_client_data(reader).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (mut mask_read, mut mask_write) = mask_stream.into_split();
|
||||||
|
|
||||||
// Send initial data to mask host
|
// Send initial data to mask host
|
||||||
if mask_stream.write_all(initial_data).await.is_err() {
|
if mask_write.write_all(initial_data).await.is_err() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Relay traffic
|
// Relay traffic
|
||||||
let (mut client_read, mut client_write) = client.into_split();
|
|
||||||
let (mut mask_read, mut mask_write) = mask_stream.into_split();
|
|
||||||
|
|
||||||
let c2m = tokio::spawn(async move {
|
let c2m = tokio::spawn(async move {
|
||||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||||
loop {
|
loop {
|
||||||
match client_read.read(&mut buf).await {
|
match reader.read(&mut buf).await {
|
||||||
Ok(0) | Err(_) => {
|
Ok(0) | Err(_) => {
|
||||||
let _ = mask_write.shutdown().await;
|
let _ = mask_write.shutdown().await;
|
||||||
break;
|
break;
|
||||||
@@ -85,11 +125,11 @@ pub async fn handle_bad_client(
|
|||||||
loop {
|
loop {
|
||||||
match mask_read.read(&mut buf).await {
|
match mask_read.read(&mut buf).await {
|
||||||
Ok(0) | Err(_) => {
|
Ok(0) | Err(_) => {
|
||||||
let _ = client_write.shutdown().await;
|
let _ = writer.shutdown().await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
if writer.write_all(&buf[..n]).await.is_err() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -105,9 +145,9 @@ pub async fn handle_bad_client(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Just consume all data from client without responding
|
/// Just consume all data from client without responding
|
||||||
async fn consume_client_data(mut client: TcpStream) {
|
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
|
||||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||||
while let Ok(n) = client.read(&mut buf).await {
|
while let Ok(n) = reader.read(&mut buf).await {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,320 @@
|
|||||||
//! Bidirectional Relay
|
//! Bidirectional Relay — poll-based, no head-of-line blocking
|
||||||
|
//!
|
||||||
|
//! ## What changed and why
|
||||||
|
//!
|
||||||
|
//! Previous implementation used a single-task `select! { biased; ... }` loop
|
||||||
|
//! where each branch called `write_all()`. This caused head-of-line blocking:
|
||||||
|
//! while `write_all()` waited for a slow writer (e.g. client on 3G downloading
|
||||||
|
//! media), the entire loop was blocked — the other direction couldn't make progress.
|
||||||
|
//!
|
||||||
|
//! Symptoms observed in production:
|
||||||
|
//! - Media loading at ~8 KB/s despite fast server connection
|
||||||
|
//! - Stop-and-go pattern with 50–500ms gaps between chunks
|
||||||
|
//! - `biased` select starving S→C direction
|
||||||
|
//! - Some users unable to load media at all
|
||||||
|
//!
|
||||||
|
//! ## New architecture
|
||||||
|
//!
|
||||||
|
//! Uses `tokio::io::copy_bidirectional` which polls both directions concurrently
|
||||||
|
//! in a single task via non-blocking `poll_read` / `poll_write` calls:
|
||||||
|
//!
|
||||||
|
//! Old (select! + write_all — BLOCKING):
|
||||||
|
//!
|
||||||
|
//! loop {
|
||||||
|
//! select! {
|
||||||
|
//! biased;
|
||||||
|
//! data = client.read() => { server.write_all(data).await; } ← BLOCKS here
|
||||||
|
//! data = server.read() => { client.write_all(data).await; } ← can't run
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! New (copy_bidirectional — CONCURRENT):
|
||||||
|
//!
|
||||||
|
//! poll(cx) {
|
||||||
|
//! // Both directions polled in the same poll cycle
|
||||||
|
//! C→S: poll_read(client) → poll_write(server) // non-blocking
|
||||||
|
//! S→C: poll_read(server) → poll_write(client) // non-blocking
|
||||||
|
//! // If one writer is Pending, the other direction still progresses
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! Benefits:
|
||||||
|
//! - No head-of-line blocking: slow client download doesn't block uploads
|
||||||
|
//! - No biased starvation: fair polling of both directions
|
||||||
|
//! - Proper flush: `copy_bidirectional` calls `poll_flush` when reader stalls,
|
||||||
|
//! so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
|
||||||
|
//! - No deadlock risk: old write_all could deadlock when both TCP buffers filled;
|
||||||
|
//! poll-based approach lets TCP flow control work correctly
|
||||||
|
//!
|
||||||
|
//! Stats tracking:
|
||||||
|
//! - `StatsIo` wraps client side, intercepts `poll_read` / `poll_write`
|
||||||
|
//! - `poll_read` on client = C→S (client sending) → `octets_from`, `msgs_from`
|
||||||
|
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
|
||||||
|
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
|
||||||
|
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
|
||||||
|
use tokio::time::Instant;
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use crate::stream::BufferPool;
|
||||||
|
|
||||||
const BUFFER_SIZE: usize = 65536;
|
// ============= Constants =============
|
||||||
|
|
||||||
/// Relay data bidirectionally between client and server
|
/// Activity timeout for iOS compatibility.
|
||||||
|
///
|
||||||
|
/// iOS keeps Telegram connections alive in background for up to 30 minutes.
|
||||||
|
/// Closing earlier causes unnecessary reconnects and handshake overhead.
|
||||||
|
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
|
||||||
|
|
||||||
|
/// Watchdog check interval — also used for periodic rate logging.
|
||||||
|
///
|
||||||
|
/// 10 seconds gives responsive timeout detection (±10s accuracy)
|
||||||
|
/// without measurable overhead from atomic reads.
|
||||||
|
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
|
// ============= CombinedStream =============
|
||||||
|
|
||||||
|
/// Combines separate read and write halves into a single bidirectional stream.
|
||||||
|
///
|
||||||
|
/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
|
||||||
|
/// but the handshake layer produces split reader/writer pairs
|
||||||
|
/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
|
||||||
|
///
|
||||||
|
/// This wrapper reunifies them with zero overhead — each trait method
|
||||||
|
/// delegates directly to the corresponding half. No buffering, no copies.
|
||||||
|
///
|
||||||
|
/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
|
||||||
|
/// so there's no aliasing even though both are called on the same `&mut self`.
|
||||||
|
struct CombinedStream<R, W> {
|
||||||
|
reader: R,
|
||||||
|
writer: W,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, W> CombinedStream<R, W> {
|
||||||
|
fn new(reader: R, writer: W) -> Self {
|
||||||
|
Self { reader, writer }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
|
||||||
|
#[inline]
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
|
||||||
|
#[inline]
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= SharedCounters =============
|
||||||
|
|
||||||
|
/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
|
||||||
|
///
|
||||||
|
/// Using `Relaxed` ordering is sufficient because:
|
||||||
|
/// - Counters are monotonically increasing (no ABA problem)
|
||||||
|
/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
|
||||||
|
/// - No ordering dependencies between different counters
|
||||||
|
struct SharedCounters {
|
||||||
|
/// Bytes read from client (C→S direction)
|
||||||
|
c2s_bytes: AtomicU64,
|
||||||
|
/// Bytes written to client (S→C direction)
|
||||||
|
s2c_bytes: AtomicU64,
|
||||||
|
/// Number of poll_read completions (≈ C→S chunks)
|
||||||
|
c2s_ops: AtomicU64,
|
||||||
|
/// Number of poll_write completions (≈ S→C chunks)
|
||||||
|
s2c_ops: AtomicU64,
|
||||||
|
/// Milliseconds since relay epoch of last I/O activity
|
||||||
|
last_activity_ms: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SharedCounters {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
c2s_bytes: AtomicU64::new(0),
|
||||||
|
s2c_bytes: AtomicU64::new(0),
|
||||||
|
c2s_ops: AtomicU64::new(0),
|
||||||
|
s2c_ops: AtomicU64::new(0),
|
||||||
|
last_activity_ms: AtomicU64::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record activity at this instant.
|
||||||
|
#[inline]
|
||||||
|
fn touch(&self, now: Instant, epoch: Instant) {
|
||||||
|
let ms = now.duration_since(epoch).as_millis() as u64;
|
||||||
|
self.last_activity_ms.store(ms, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// How long since last recorded activity.
|
||||||
|
fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
|
||||||
|
let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
|
||||||
|
let now_ms = now.duration_since(epoch).as_millis() as u64;
|
||||||
|
Duration::from_millis(now_ms.saturating_sub(last_ms))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= StatsIo =============
|
||||||
|
|
||||||
|
/// Transparent I/O wrapper that tracks per-user statistics and activity.
|
||||||
|
///
|
||||||
|
/// Wraps the **client** side of the relay. Direction mapping:
|
||||||
|
///
|
||||||
|
/// | poll method | direction | stats updated |
|
||||||
|
/// |-------------|-----------|--------------------------------------|
|
||||||
|
/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters |
|
||||||
|
/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters |
|
||||||
|
///
|
||||||
|
/// Both update the shared activity timestamp for the watchdog.
|
||||||
|
///
|
||||||
|
/// Note on message counts: the original code counted one `read()`/`write_all()`
|
||||||
|
/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
|
||||||
|
/// Byte counts are identical; op counts may differ slightly due to different
|
||||||
|
/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
|
||||||
|
struct StatsIo<S> {
|
||||||
|
inner: S,
|
||||||
|
counters: Arc<SharedCounters>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
user: String,
|
||||||
|
epoch: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> StatsIo<S> {
|
||||||
|
fn new(
|
||||||
|
inner: S,
|
||||||
|
counters: Arc<SharedCounters>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
user: String,
|
||||||
|
epoch: Instant,
|
||||||
|
) -> Self {
|
||||||
|
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||||
|
counters.touch(Instant::now(), epoch);
|
||||||
|
Self { inner, counters, stats, user, epoch }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
let before = buf.filled().len();
|
||||||
|
|
||||||
|
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||||
|
Poll::Ready(Ok(())) => {
|
||||||
|
let n = buf.filled().len() - before;
|
||||||
|
if n > 0 {
|
||||||
|
// C→S: client sent data
|
||||||
|
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
|
||||||
|
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
|
||||||
|
this.counters.touch(Instant::now(), this.epoch);
|
||||||
|
|
||||||
|
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||||
|
this.stats.increment_user_msgs_from(&this.user);
|
||||||
|
|
||||||
|
trace!(user = %this.user, bytes = n, "C->S");
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||||
|
Poll::Ready(Ok(n)) => {
|
||||||
|
if n > 0 {
|
||||||
|
// S→C: data written to client
|
||||||
|
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
|
||||||
|
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
|
||||||
|
this.counters.touch(Instant::now(), this.epoch);
|
||||||
|
|
||||||
|
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||||
|
this.stats.increment_user_msgs_to(&this.user);
|
||||||
|
|
||||||
|
trace!(user = %this.user, bytes = n, "S->C");
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Relay =============
|
||||||
|
|
||||||
|
/// Relay data bidirectionally between client and server.
|
||||||
|
///
|
||||||
|
/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
|
||||||
|
///
|
||||||
|
/// ## API compatibility
|
||||||
|
///
|
||||||
|
/// Signature is identical to the previous implementation. The `_buffer_pool`
|
||||||
|
/// parameter is retained for call-site compatibility — `copy_bidirectional`
|
||||||
|
/// manages its own internal buffers (8 KB per direction).
|
||||||
|
///
|
||||||
|
/// ## Guarantees preserved
|
||||||
|
///
|
||||||
|
/// - Activity timeout: 30 minutes of inactivity → clean shutdown
|
||||||
|
/// - Per-user stats: bytes and ops counted per direction
|
||||||
|
/// - Periodic rate logging: every 10 seconds when active
|
||||||
|
/// - Clean shutdown: both write sides are shut down on exit
|
||||||
|
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
|
||||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||||
mut client_reader: CR,
|
client_reader: CR,
|
||||||
mut client_writer: CW,
|
client_writer: CW,
|
||||||
mut server_reader: SR,
|
server_reader: SR,
|
||||||
mut server_writer: SW,
|
server_writer: SW,
|
||||||
user: &str,
|
user: &str,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
|
_buffer_pool: Arc<BufferPool>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
CR: AsyncRead + Unpin + Send + 'static,
|
CR: AsyncRead + Unpin + Send + 'static,
|
||||||
@@ -24,139 +322,145 @@ where
|
|||||||
SR: AsyncRead + Unpin + Send + 'static,
|
SR: AsyncRead + Unpin + Send + 'static,
|
||||||
SW: AsyncWrite + Unpin + Send + 'static,
|
SW: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
let user_c2s = user.to_string();
|
let epoch = Instant::now();
|
||||||
let user_s2c = user.to_string();
|
let counters = Arc::new(SharedCounters::new());
|
||||||
|
let user_owned = user.to_string();
|
||||||
|
|
||||||
// Используем Arc::clone вместо stats.clone()
|
// ── Combine split halves into bidirectional streams ──────────────
|
||||||
let stats_c2s = Arc::clone(&stats);
|
let client_combined = CombinedStream::new(client_reader, client_writer);
|
||||||
let stats_s2c = Arc::clone(&stats);
|
let mut server = CombinedStream::new(server_reader, server_writer);
|
||||||
|
|
||||||
let c2s_bytes = Arc::new(AtomicU64::new(0));
|
// Wrap client with stats/activity tracking
|
||||||
let s2c_bytes = Arc::new(AtomicU64::new(0));
|
let mut client = StatsIo::new(
|
||||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
client_combined,
|
||||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
Arc::clone(&counters),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user_owned.clone(),
|
||||||
|
epoch,
|
||||||
|
);
|
||||||
|
|
||||||
// Client -> Server task
|
// ── Watchdog: activity timeout + periodic rate logging ──────────
|
||||||
let c2s = tokio::spawn(async move {
|
let wd_counters = Arc::clone(&counters);
|
||||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
let wd_user = user_owned.clone();
|
||||||
let mut total_bytes = 0u64;
|
|
||||||
let mut msg_count = 0u64;
|
let watchdog = async {
|
||||||
|
let mut prev_c2s: u64 = 0;
|
||||||
|
let mut prev_s2c: u64 = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match client_reader.read(&mut buf).await {
|
tokio::time::sleep(WATCHDOG_INTERVAL).await;
|
||||||
Ok(0) => {
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let idle = wd_counters.idle_duration(now, epoch);
|
||||||
|
|
||||||
|
// ── Activity timeout ────────────────────────────────────
|
||||||
|
if idle >= ACTIVITY_TIMEOUT {
|
||||||
|
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
warn!(
|
||||||
|
user = %wd_user,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
idle_secs = idle.as_secs(),
|
||||||
|
"Activity timeout"
|
||||||
|
);
|
||||||
|
return; // Causes select! to cancel copy_bidirectional
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Periodic rate logging ───────────────────────────────
|
||||||
|
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
let c2s_delta = c2s - prev_c2s;
|
||||||
|
let s2c_delta = s2c - prev_s2c;
|
||||||
|
|
||||||
|
if c2s_delta > 0 || s2c_delta > 0 {
|
||||||
|
let secs = WATCHDOG_INTERVAL.as_secs_f64();
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_c2s,
|
user = %wd_user,
|
||||||
total_bytes = total_bytes,
|
c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
|
||||||
msgs = msg_count,
|
s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
|
||||||
"Client closed connection (C->S)"
|
c2s_total = c2s,
|
||||||
|
s2c_total = s2c,
|
||||||
|
"Relay active"
|
||||||
);
|
);
|
||||||
let _ = server_writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
|
||||||
total_bytes += n as u64;
|
|
||||||
msg_count += 1;
|
|
||||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
|
||||||
|
|
||||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
prev_c2s = c2s;
|
||||||
stats_c2s.increment_user_msgs_from(&user_c2s);
|
prev_s2c = s2c;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
trace!(
|
// ── Run bidirectional copy + watchdog concurrently ───────────────
|
||||||
user = %user_c2s,
|
//
|
||||||
bytes = n,
|
// copy_bidirectional polls both directions in the same poll() call:
|
||||||
total = total_bytes,
|
// C→S: poll_read(client/StatsIo) → poll_write(server)
|
||||||
data_preview = %hex::encode(&buf[..n.min(32)]),
|
// S→C: poll_read(server) → poll_write(client/StatsIo)
|
||||||
"C->S data"
|
//
|
||||||
);
|
// When one direction's writer returns Pending, the other direction
|
||||||
|
// continues — no head-of-line blocking.
|
||||||
|
//
|
||||||
|
// When the watchdog fires, select! drops the copy future,
|
||||||
|
// releasing the &mut borrows on client and server.
|
||||||
|
let copy_result = tokio::select! {
|
||||||
|
result = copy_bidirectional(&mut client, &mut server) => Some(result),
|
||||||
|
_ = watchdog => None, // Activity timeout — cancel relay
|
||||||
|
};
|
||||||
|
|
||||||
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
// ── Clean shutdown ──────────────────────────────────────────────
|
||||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
// After select!, the losing future is dropped, borrows released.
|
||||||
break;
|
// Shut down both write sides for clean TCP FIN.
|
||||||
}
|
let _ = client.shutdown().await;
|
||||||
if let Err(e) = server_writer.flush().await {
|
let _ = server.shutdown().await;
|
||||||
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Server -> Client task
|
// ── Final logging ───────────────────────────────────────────────
|
||||||
let s2c = tokio::spawn(async move {
|
let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
|
||||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
|
||||||
let mut total_bytes = 0u64;
|
let duration = epoch.elapsed();
|
||||||
let mut msg_count = 0u64;
|
|
||||||
|
|
||||||
loop {
|
match copy_result {
|
||||||
match server_reader.read(&mut buf).await {
|
Some(Ok((c2s, s2c))) => {
|
||||||
Ok(0) => {
|
// Normal completion — one side closed the connection
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_s2c,
|
user = %user_owned,
|
||||||
total_bytes = total_bytes,
|
c2s_bytes = c2s,
|
||||||
msgs = msg_count,
|
s2c_bytes = s2c,
|
||||||
"Server closed connection (S->C)"
|
c2s_msgs = c2s_ops,
|
||||||
);
|
s2c_msgs = s2c_ops,
|
||||||
let _ = client_writer.shutdown().await;
|
duration_secs = duration.as_secs(),
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(n) => {
|
|
||||||
total_bytes += n as u64;
|
|
||||||
msg_count += 1;
|
|
||||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
|
||||||
|
|
||||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
|
||||||
stats_s2c.increment_user_msgs_to(&user_s2c);
|
|
||||||
|
|
||||||
trace!(
|
|
||||||
user = %user_s2c,
|
|
||||||
bytes = n,
|
|
||||||
total = total_bytes,
|
|
||||||
data_preview = %hex::encode(&buf[..n.min(32)]),
|
|
||||||
"S->C data"
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
|
||||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Err(e) = client_writer.flush().await {
|
|
||||||
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wait for either direction to complete
|
|
||||||
tokio::select! {
|
|
||||||
result = c2s => {
|
|
||||||
if let Err(e) = result {
|
|
||||||
warn!(error = %e, "C->S task panicked");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result = s2c => {
|
|
||||||
if let Err(e) = result {
|
|
||||||
warn!(error = %e, "S->C task panicked");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
|
|
||||||
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
|
|
||||||
"Relay finished"
|
"Relay finished"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Some(Err(e)) => {
|
||||||
|
// I/O error in one of the directions
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
debug!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
error = %e,
|
||||||
|
"Relay error"
|
||||||
|
);
|
||||||
|
Err(e.into())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// Activity timeout (watchdog fired)
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
debug!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
"Relay finished (activity timeout)"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
373
src/stats/mod.rs
373
src/stats/mod.rs
@@ -1,29 +1,28 @@
|
|||||||
//! Statistics
|
//! Statistics and replay protection
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::{Instant, Duration};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use parking_lot::RwLock;
|
use parking_lot::Mutex;
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
use std::num::NonZeroUsize;
|
use std::num::NonZeroUsize;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
// ============= Stats =============
|
||||||
|
|
||||||
/// Thread-safe statistics
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Stats {
|
pub struct Stats {
|
||||||
// Global counters
|
|
||||||
connects_all: AtomicU64,
|
connects_all: AtomicU64,
|
||||||
connects_bad: AtomicU64,
|
connects_bad: AtomicU64,
|
||||||
handshake_timeouts: AtomicU64,
|
handshake_timeouts: AtomicU64,
|
||||||
|
|
||||||
// Per-user stats
|
|
||||||
user_stats: DashMap<String, UserStats>,
|
user_stats: DashMap<String, UserStats>,
|
||||||
|
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||||
// Start time
|
|
||||||
start_time: RwLock<Option<Instant>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per-user statistics
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct UserStats {
|
pub struct UserStats {
|
||||||
pub connects: AtomicU64,
|
pub connects: AtomicU64,
|
||||||
@@ -41,42 +40,20 @@ impl Stats {
|
|||||||
stats
|
stats
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global stats
|
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
|
||||||
pub fn increment_connects_all(&self) {
|
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
|
||||||
self.connects_all.fetch_add(1, Ordering::Relaxed);
|
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
|
||||||
}
|
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
|
||||||
|
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
|
||||||
|
|
||||||
pub fn increment_connects_bad(&self) {
|
|
||||||
self.connects_bad.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn increment_handshake_timeouts(&self) {
|
|
||||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_connects_all(&self) -> u64 {
|
|
||||||
self.connects_all.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_connects_bad(&self) -> u64 {
|
|
||||||
self.connects_bad.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// User stats
|
|
||||||
pub fn increment_user_connects(&self, user: &str) {
|
pub fn increment_user_connects(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.connects.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.connects
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_curr_connects(&self, user: &str) {
|
pub fn increment_user_curr_connects(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.curr_connects
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||||
@@ -86,47 +63,33 @@ impl Stats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
||||||
self.user_stats
|
self.user_stats.get(user)
|
||||||
.get(user)
|
|
||||||
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.octets_from_client
|
|
||||||
.fetch_add(bytes, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.octets_to_client
|
|
||||||
.fetch_add(bytes, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.msgs_from_client
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||||
self.user_stats
|
self.user_stats.entry(user.to_string()).or_default()
|
||||||
.entry(user.to_string())
|
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||||
.or_default()
|
|
||||||
.msgs_to_client
|
|
||||||
.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||||
self.user_stats
|
self.user_stats.get(user)
|
||||||
.get(user)
|
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
s.octets_from_client.load(Ordering::Relaxed) +
|
s.octets_from_client.load(Ordering::Relaxed) +
|
||||||
s.octets_to_client.load(Ordering::Relaxed)
|
s.octets_to_client.load(Ordering::Relaxed)
|
||||||
@@ -141,37 +104,209 @@ impl Stats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Arc<Stats> Hightech Stats :D
|
// ============= Replay Checker =============
|
||||||
|
|
||||||
/// Replay attack checker using LRU cache
|
|
||||||
pub struct ReplayChecker {
|
pub struct ReplayChecker {
|
||||||
handshakes: RwLock<LruCache<Vec<u8>, ()>>,
|
shards: Vec<Mutex<ReplayShard>>,
|
||||||
tls_digests: RwLock<LruCache<Vec<u8>, ()>>,
|
shard_mask: usize,
|
||||||
|
window: Duration,
|
||||||
|
checks: AtomicU64,
|
||||||
|
hits: AtomicU64,
|
||||||
|
additions: AtomicU64,
|
||||||
|
cleanups: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ReplayEntry {
|
||||||
|
seen_at: Instant,
|
||||||
|
seq: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ReplayShard {
|
||||||
|
cache: LruCache<Box<[u8]>, ReplayEntry>,
|
||||||
|
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
|
||||||
|
seq_counter: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayShard {
|
||||||
|
fn new(cap: NonZeroUsize) -> Self {
|
||||||
|
Self {
|
||||||
|
cache: LruCache::new(cap),
|
||||||
|
queue: VecDeque::with_capacity(cap.get()),
|
||||||
|
seq_counter: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_seq(&mut self) -> u64 {
|
||||||
|
self.seq_counter += 1;
|
||||||
|
self.seq_counter
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||||
|
if window.is_zero() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let cutoff = now.checked_sub(window).unwrap_or(now);
|
||||||
|
|
||||||
|
while let Some((ts, _, _)) = self.queue.front() {
|
||||||
|
if *ts >= cutoff {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
|
||||||
|
|
||||||
|
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
|
||||||
|
// between Borrow<[u8]> and Borrow<Box<[u8]>>
|
||||||
|
if let Some(entry) = self.cache.peek(key.as_ref()) {
|
||||||
|
if entry.seq == queue_seq {
|
||||||
|
self.cache.pop(key.as_ref());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
|
||||||
|
self.cleanup(now, window);
|
||||||
|
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
|
||||||
|
self.cache.get(key).is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
|
||||||
|
self.cleanup(now, window);
|
||||||
|
|
||||||
|
let seq = self.next_seq();
|
||||||
|
let boxed_key: Box<[u8]> = key.into();
|
||||||
|
|
||||||
|
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
|
||||||
|
self.queue.push_back((now, boxed_key, seq));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn len(&self) -> usize {
|
||||||
|
self.cache.len()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReplayChecker {
|
impl ReplayChecker {
|
||||||
pub fn new(capacity: usize) -> Self {
|
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||||
let cap = NonZeroUsize::new(capacity.max(1)).unwrap();
|
let num_shards = 64;
|
||||||
|
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||||
|
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||||
|
|
||||||
|
let mut shards = Vec::with_capacity(num_shards);
|
||||||
|
for _ in 0..num_shards {
|
||||||
|
shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||||
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
handshakes: RwLock::new(LruCache::new(cap)),
|
shards,
|
||||||
tls_digests: RwLock::new(LruCache::new(cap)),
|
shard_mask: num_shards - 1,
|
||||||
|
window,
|
||||||
|
checks: AtomicU64::new(0),
|
||||||
|
hits: AtomicU64::new(0),
|
||||||
|
additions: AtomicU64::new(0),
|
||||||
|
cleanups: AtomicU64::new(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
fn get_shard_idx(&self, key: &[u8]) -> usize {
|
||||||
self.handshakes.read().contains(&data.to_vec())
|
let mut hasher = DefaultHasher::new();
|
||||||
|
key.hash(&mut hasher);
|
||||||
|
(hasher.finish() as usize) & self.shard_mask
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_handshake(&self, data: &[u8]) {
|
fn check(&self, data: &[u8]) -> bool {
|
||||||
self.handshakes.write().put(data.to_vec(), ());
|
self.checks.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let idx = self.get_shard_idx(data);
|
||||||
|
let mut shard = self.shards[idx].lock();
|
||||||
|
let found = shard.check(data, Instant::now(), self.window);
|
||||||
|
if found {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
found
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
fn add(&self, data: &[u8]) {
|
||||||
self.tls_digests.read().contains(&data.to_vec())
|
self.additions.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let idx = self.get_shard_idx(data);
|
||||||
|
let mut shard = self.shards[idx].lock();
|
||||||
|
shard.add(data, Instant::now(), self.window);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
|
||||||
self.tls_digests.write().put(data.to_vec(), ());
|
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
|
||||||
|
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
|
||||||
|
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
|
||||||
|
|
||||||
|
pub fn stats(&self) -> ReplayStats {
|
||||||
|
let mut total_entries = 0;
|
||||||
|
let mut total_queue_len = 0;
|
||||||
|
for shard in &self.shards {
|
||||||
|
let s = shard.lock();
|
||||||
|
total_entries += s.cache.len();
|
||||||
|
total_queue_len += s.queue.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
ReplayStats {
|
||||||
|
total_entries,
|
||||||
|
total_queue_len,
|
||||||
|
total_checks: self.checks.load(Ordering::Relaxed),
|
||||||
|
total_hits: self.hits.load(Ordering::Relaxed),
|
||||||
|
total_additions: self.additions.load(Ordering::Relaxed),
|
||||||
|
total_cleanups: self.cleanups.load(Ordering::Relaxed),
|
||||||
|
num_shards: self.shards.len(),
|
||||||
|
window_secs: self.window.as_secs(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_periodic_cleanup(&self) {
|
||||||
|
let interval = if self.window.as_secs() > 60 {
|
||||||
|
Duration::from_secs(30)
|
||||||
|
} else {
|
||||||
|
Duration::from_secs(self.window.as_secs().max(1) / 2)
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(interval).await;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut cleaned = 0usize;
|
||||||
|
|
||||||
|
for shard_mutex in &self.shards {
|
||||||
|
let mut shard = shard_mutex.lock();
|
||||||
|
let before = shard.len();
|
||||||
|
shard.cleanup(now, self.window);
|
||||||
|
let after = shard.len();
|
||||||
|
cleaned += before.saturating_sub(after);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cleanups.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
if cleaned > 0 {
|
||||||
|
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ReplayStats {
|
||||||
|
pub total_entries: usize,
|
||||||
|
pub total_queue_len: usize,
|
||||||
|
pub total_checks: u64,
|
||||||
|
pub total_hits: u64,
|
||||||
|
pub total_additions: u64,
|
||||||
|
pub total_cleanups: u64,
|
||||||
|
pub num_shards: usize,
|
||||||
|
pub window_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayStats {
|
||||||
|
pub fn hit_rate(&self) -> f64 {
|
||||||
|
if self.total_checks == 0 { 0.0 }
|
||||||
|
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ghost_ratio(&self) -> f64 {
|
||||||
|
if self.total_entries == 0 { 0.0 }
|
||||||
|
else { self.total_queue_len as f64 / self.total_entries as f64 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,42 +317,60 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_stats_shared_counters() {
|
fn test_stats_shared_counters() {
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
stats.increment_connects_all();
|
||||||
// Симулируем использование из разных "задач"
|
stats.increment_connects_all();
|
||||||
let stats1 = Arc::clone(&stats);
|
stats.increment_connects_all();
|
||||||
let stats2 = Arc::clone(&stats);
|
|
||||||
|
|
||||||
stats1.increment_connects_all();
|
|
||||||
stats2.increment_connects_all();
|
|
||||||
stats1.increment_connects_all();
|
|
||||||
|
|
||||||
// Все инкременты должны быть видны
|
|
||||||
assert_eq!(stats.get_connects_all(), 3);
|
assert_eq!(stats.get_connects_all(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_user_stats_shared() {
|
fn test_replay_checker_basic() {
|
||||||
let stats = Arc::new(Stats::new());
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
|
assert!(!checker.check_handshake(b"test1"));
|
||||||
let stats1 = Arc::clone(&stats);
|
checker.add_handshake(b"test1");
|
||||||
let stats2 = Arc::clone(&stats);
|
assert!(checker.check_handshake(b"test1"));
|
||||||
|
assert!(!checker.check_handshake(b"test2"));
|
||||||
stats1.add_user_octets_from("user1", 100);
|
|
||||||
stats2.add_user_octets_from("user1", 200);
|
|
||||||
stats1.add_user_octets_to("user1", 50);
|
|
||||||
|
|
||||||
assert_eq!(stats.get_user_total_octets("user1"), 350);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_concurrent_user_connects() {
|
fn test_replay_checker_duplicate_add() {
|
||||||
let stats = Arc::new(Stats::new());
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
|
checker.add_handshake(b"dup");
|
||||||
|
checker.add_handshake(b"dup");
|
||||||
|
assert!(checker.check_handshake(b"dup"));
|
||||||
|
}
|
||||||
|
|
||||||
stats.increment_user_curr_connects("user1");
|
#[test]
|
||||||
stats.increment_user_curr_connects("user1");
|
fn test_replay_checker_expiration() {
|
||||||
assert_eq!(stats.get_user_curr_connects("user1"), 2);
|
let checker = ReplayChecker::new(100, Duration::from_millis(50));
|
||||||
|
checker.add_handshake(b"expire");
|
||||||
|
assert!(checker.check_handshake(b"expire"));
|
||||||
|
std::thread::sleep(Duration::from_millis(100));
|
||||||
|
assert!(!checker.check_handshake(b"expire"));
|
||||||
|
}
|
||||||
|
|
||||||
stats.decrement_user_curr_connects("user1");
|
#[test]
|
||||||
assert_eq!(stats.get_user_curr_connects("user1"), 1);
|
fn test_replay_checker_stats() {
|
||||||
|
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||||
|
checker.add_handshake(b"k1");
|
||||||
|
checker.add_handshake(b"k2");
|
||||||
|
checker.check_handshake(b"k1");
|
||||||
|
checker.check_handshake(b"k3");
|
||||||
|
let stats = checker.stats();
|
||||||
|
assert_eq!(stats.total_additions, 2);
|
||||||
|
assert_eq!(stats.total_checks, 2);
|
||||||
|
assert_eq!(stats.total_hits, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_replay_checker_many_keys() {
|
||||||
|
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
|
||||||
|
for i in 0..500u32 {
|
||||||
|
checker.add(&i.to_le_bytes());
|
||||||
|
}
|
||||||
|
for i in 0..500u32 {
|
||||||
|
assert!(checker.check(&i.to_le_bytes()));
|
||||||
|
}
|
||||||
|
assert_eq!(checker.stats().total_entries, 500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
451
src/stream/buffer_pool.rs
Normal file
451
src/stream/buffer_pool.rs
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
//! Reusable buffer pool to avoid allocations in hot paths
|
||||||
|
//!
|
||||||
|
//! This module provides a thread-safe pool of BytesMut buffers
|
||||||
|
//! that can be reused across connections to reduce allocation pressure.
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use crossbeam_queue::ArrayQueue;
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
// ============= Configuration =============
|
||||||
|
|
||||||
|
/// Default buffer size
|
||||||
|
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||||
|
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
||||||
|
|
||||||
|
/// Default maximum number of pooled buffers
|
||||||
|
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||||
|
|
||||||
|
// ============= Buffer Pool =============
|
||||||
|
|
||||||
|
/// Thread-safe pool of reusable buffers
|
||||||
|
pub struct BufferPool {
|
||||||
|
/// Queue of available buffers
|
||||||
|
buffers: ArrayQueue<BytesMut>,
|
||||||
|
/// Size of each buffer
|
||||||
|
buffer_size: usize,
|
||||||
|
/// Maximum number of buffers to pool
|
||||||
|
max_buffers: usize,
|
||||||
|
/// Total allocated buffers (including in-use)
|
||||||
|
allocated: AtomicUsize,
|
||||||
|
/// Number of times we had to create a new buffer
|
||||||
|
misses: AtomicUsize,
|
||||||
|
/// Number of successful reuses
|
||||||
|
hits: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BufferPool {
|
||||||
|
/// Create a new buffer pool with default settings
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a buffer pool with custom configuration
|
||||||
|
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffers: ArrayQueue::new(max_buffers),
|
||||||
|
buffer_size,
|
||||||
|
max_buffers,
|
||||||
|
allocated: AtomicUsize::new(0),
|
||||||
|
misses: AtomicUsize::new(0),
|
||||||
|
hits: AtomicUsize::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a buffer from the pool, or create a new one if empty
|
||||||
|
pub fn get(self: &Arc<Self>) -> PooledBuffer {
|
||||||
|
match self.buffers.pop() {
|
||||||
|
Some(mut buffer) => {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
buffer.clear();
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(buffer),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.misses.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(BytesMut::with_capacity(self.buffer_size)),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to get a buffer, returns None if pool is empty
|
||||||
|
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
|
||||||
|
self.buffers.pop().map(|mut buffer| {
|
||||||
|
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
buffer.clear();
|
||||||
|
PooledBuffer {
|
||||||
|
buffer: Some(buffer),
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a buffer to the pool
|
||||||
|
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||||
|
// Clear the buffer but keep capacity
|
||||||
|
buffer.clear();
|
||||||
|
|
||||||
|
// Only return if we haven't exceeded max and buffer is right size
|
||||||
|
if buffer.capacity() >= self.buffer_size {
|
||||||
|
// Try to push to pool, if full just drop
|
||||||
|
let _ = self.buffers.push(buffer);
|
||||||
|
}
|
||||||
|
// If buffer was dropped (pool full), decrement allocated
|
||||||
|
// Actually we don't decrement here because the buffer might have been
|
||||||
|
// grown beyond our size - we just let it go
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get pool statistics
|
||||||
|
pub fn stats(&self) -> PoolStats {
|
||||||
|
PoolStats {
|
||||||
|
pooled: self.buffers.len(),
|
||||||
|
allocated: self.allocated.load(Ordering::Relaxed),
|
||||||
|
max_buffers: self.max_buffers,
|
||||||
|
buffer_size: self.buffer_size,
|
||||||
|
hits: self.hits.load(Ordering::Relaxed),
|
||||||
|
misses: self.misses.load(Ordering::Relaxed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get buffer size
|
||||||
|
pub fn buffer_size(&self) -> usize {
|
||||||
|
self.buffer_size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Preallocate buffers to fill the pool
|
||||||
|
pub fn preallocate(&self, count: usize) {
|
||||||
|
let to_alloc = count.min(self.max_buffers);
|
||||||
|
for _ in 0..to_alloc {
|
||||||
|
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BufferPool {
    /// Equivalent to [`BufferPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||||
|
|
||||||
|
// ============= Pool Statistics =============
|
||||||
|
|
||||||
|
/// Statistics about buffer pool usage
///
/// A point-in-time snapshot produced by [`BufferPool::stats`]; the fields
/// are not updated after the snapshot is taken.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Current number of buffers in pool
    pub pooled: usize,
    /// Total buffers allocated (in-use + pooled)
    pub allocated: usize,
    /// Maximum buffers allowed
    pub max_buffers: usize,
    /// Size of each buffer
    pub buffer_size: usize,
    /// Number of cache hits (reused buffer)
    pub hits: usize,
    /// Number of cache misses (new allocation)
    pub misses: usize,
}
|
||||||
|
|
||||||
|
impl PoolStats {
|
||||||
|
/// Get hit rate as percentage
|
||||||
|
pub fn hit_rate(&self) -> f64 {
|
||||||
|
let total = self.hits + self.misses;
|
||||||
|
if total == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
(self.hits as f64 / total as f64) * 100.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Pooled Buffer =============
|
||||||
|
|
||||||
|
/// A buffer that automatically returns to the pool when dropped
pub struct PooledBuffer {
    /// Inner buffer; `None` only after `take()` has moved it out
    /// (or transiently during `drop`).
    buffer: Option<BytesMut>,
    /// Owning pool that receives the buffer back on drop.
    pool: Arc<BufferPool>,
}
|
||||||
|
|
||||||
|
impl PooledBuffer {
|
||||||
|
/// Take the inner buffer, preventing return to pool
|
||||||
|
pub fn take(mut self) -> BytesMut {
|
||||||
|
self.buffer.take().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the capacity of the buffer
|
||||||
|
pub fn capacity(&self) -> usize {
|
||||||
|
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the length of data in buffer
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
if let Some(ref mut b) = self.buffer {
|
||||||
|
b.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for PooledBuffer {
    type Target = BytesMut;

    /// Borrow the inner buffer.
    ///
    /// # Panics
    /// Panics if the buffer was removed with `take()`.
    fn deref(&self) -> &Self::Target {
        self.buffer.as_ref().expect("buffer taken")
    }
}

impl DerefMut for PooledBuffer {
    /// Mutably borrow the inner buffer; panics after `take()`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.buffer.as_mut().expect("buffer taken")
    }
}

impl Drop for PooledBuffer {
    /// Hand the buffer back to the owning pool (no-op after `take()`).
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}

impl AsRef<[u8]> for PooledBuffer {
    /// View the stored bytes; empty slice after `take()`.
    fn as_ref(&self) -> &[u8] {
        self.buffer.as_ref().map(|b| b.as_ref()).unwrap_or(&[])
    }
}

impl AsMut<[u8]> for PooledBuffer {
    /// Mutable view of the stored bytes; empty slice after `take()`.
    fn as_mut(&mut self) -> &mut [u8] {
        self.buffer.as_mut().map(|b| b.as_mut()).unwrap_or(&mut [])
    }
}
|
||||||
|
|
||||||
|
// ============= Scoped Buffer =============
|
||||||
|
|
||||||
|
/// A buffer that can be used for a scoped operation
/// Useful for ensuring buffer is returned even on early return
// The underlying PooledBuffer is cleared on scope entry and again on drop,
// so the scope always starts and ends with an empty buffer.
pub struct ScopedBuffer<'a> {
    /// Borrowed pooled buffer; exclusive for the lifetime of the scope.
    buffer: &'a mut PooledBuffer,
}
|
||||||
|
|
||||||
|
impl<'a> ScopedBuffer<'a> {
    /// Create a new scoped buffer
    ///
    /// Clears `buffer` on entry so the scope starts empty; it is cleared
    /// again when the `ScopedBuffer` is dropped.
    pub fn new(buffer: &'a mut PooledBuffer) -> Self {
        buffer.clear();
        Self { buffer }
    }
}
|
||||||
|
|
||||||
|
impl<'a> Deref for ScopedBuffer<'a> {
    type Target = BytesMut;

    /// Borrow the underlying buffer (panics if it was `take()`n).
    fn deref(&self) -> &Self::Target {
        self.buffer.deref()
    }
}

impl<'a> DerefMut for ScopedBuffer<'a> {
    /// Mutably borrow the underlying buffer (panics if it was `take()`n).
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.buffer.deref_mut()
    }
}

impl<'a> Drop for ScopedBuffer<'a> {
    /// Clear the borrowed buffer on scope exit, even on early return.
    fn drop(&mut self) {
        self.buffer.clear();
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// get/drop round-trip: miss on first get, buffer pooled on drop,
    /// hit (and cleared buffer) on second get.
    #[test]
    fn test_pool_basic() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));

        // Get a buffer
        let mut buf1 = pool.get();
        buf1.extend_from_slice(b"hello");
        assert_eq!(&buf1[..], b"hello");

        // Drop returns to pool
        drop(buf1);

        let stats = pool.stats();
        assert_eq!(stats.pooled, 1);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Get again - should reuse
        let buf2 = pool.get();
        assert!(buf2.is_empty()); // Buffer was cleared

        let stats = pool.stats();
        assert_eq!(stats.pooled, 0);
        assert_eq!(stats.hits, 1);
    }

    /// Several outstanding buffers are all tracked and all return on drop.
    #[test]
    fn test_pool_multiple_buffers() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));

        // Get multiple buffers
        let buf1 = pool.get();
        let buf2 = pool.get();
        let buf3 = pool.get();

        let stats = pool.stats();
        assert_eq!(stats.allocated, 3);
        assert_eq!(stats.pooled, 0);

        // Return all
        drop(buf1);
        drop(buf2);
        drop(buf3);

        let stats = pool.stats();
        assert_eq!(stats.pooled, 3);
    }

    /// Returning more buffers than `max_buffers` silently drops the excess.
    #[test]
    fn test_pool_overflow() {
        let pool = Arc::new(BufferPool::with_config(1024, 2));

        // Get 3 buffers (more than max)
        let buf1 = pool.get();
        let buf2 = pool.get();
        let buf3 = pool.get();

        // Return all - only 2 should be pooled
        drop(buf1);
        drop(buf2);
        drop(buf3);

        let stats = pool.stats();
        assert_eq!(stats.pooled, 2);
    }

    /// `take()` detaches the buffer so it never returns to the pool.
    #[test]
    fn test_pool_take() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));

        let mut buf = pool.get();
        buf.extend_from_slice(b"data");

        // Take ownership, buffer should not return to pool
        let taken = buf.take();
        assert_eq!(&taken[..], b"data");

        let stats = pool.stats();
        assert_eq!(stats.pooled, 0);
    }

    /// `preallocate` fills the idle queue and counts the allocations.
    #[test]
    fn test_pool_preallocate() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));
        pool.preallocate(5);

        let stats = pool.stats();
        assert_eq!(stats.pooled, 5);
        assert_eq!(stats.allocated, 5);
    }

    /// `try_get` never allocates: None on empty pool, Some once preallocated.
    #[test]
    fn test_pool_try_get() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));

        // Pool is empty, try_get returns None
        assert!(pool.try_get().is_none());

        // Add a buffer to pool
        pool.preallocate(1);

        // Now try_get should succeed
        assert!(pool.try_get().is_some());
        assert!(pool.try_get().is_none());
    }

    /// hit_rate reflects the hit/miss ratio (2 hits, 1 miss ≈ 66.67%).
    #[test]
    fn test_hit_rate() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));

        // First get is a miss
        let buf1 = pool.get();
        drop(buf1);

        // Second get is a hit
        let buf2 = pool.get();
        drop(buf2);

        // Third get is a hit
        let _buf3 = pool.get();

        let stats = pool.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 66.67).abs() < 1.0);
    }

    /// ScopedBuffer clears the underlying buffer when the scope ends.
    #[test]
    fn test_scoped_buffer() {
        let pool = Arc::new(BufferPool::with_config(1024, 10));
        let mut buf = pool.get();

        {
            let mut scoped = ScopedBuffer::new(&mut buf);
            scoped.extend_from_slice(b"scoped data");
            assert_eq!(&scoped[..], b"scoped data");
        }

        // After scoped is dropped, buffer is cleared
        assert!(buf.is_empty());
    }

    /// Smoke test: concurrent get/drop from 10 threads neither deadlocks
    /// nor loses buffers.
    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let pool = Arc::new(BufferPool::with_config(1024, 100));
        let mut handles = vec![];

        for _ in 0..10 {
            let pool_clone = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let mut buf = pool_clone.get();
                    buf.extend_from_slice(b"test");
                    // buf auto-returned on drop
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let stats = pool.stats();
        // All buffers should be returned
        assert!(stats.pooled > 0);
    }
}
|
||||||
File diff suppressed because it is too large
Load Diff
189
src/stream/frame.rs
Normal file
189
src/stream/frame.rs
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
//! MTProto frame types and traits
|
||||||
|
//!
|
||||||
|
//! This module defines the common types and traits used by all
|
||||||
|
//! frame encoding/decoding implementations.
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use std::io::Result;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
|
||||||
|
// ============= Frame Types =============
|
||||||
|
|
||||||
|
/// A decoded MTProto frame
///
/// `data` is the payload with all transport framing (length headers,
/// padding) already stripped; transport-level flags live in `meta`.
#[derive(Debug, Clone)]
pub struct Frame {
    /// Frame payload data
    pub data: Bytes,
    /// Frame metadata
    pub meta: FrameMeta,
}
|
||||||
|
|
||||||
|
impl Frame {
|
||||||
|
/// Create a new frame with data and default metadata
|
||||||
|
pub fn new(data: Bytes) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
meta: FrameMeta::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new frame with data and metadata
|
||||||
|
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
|
||||||
|
Self { data, meta }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an empty frame
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self::new(Bytes::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if frame is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get frame length
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a QuickAck request frame
|
||||||
|
pub fn quickack(data: Bytes) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
meta: FrameMeta {
|
||||||
|
quickack: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a simple ACK frame
|
||||||
|
pub fn simple_ack(data: Bytes) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
meta: FrameMeta {
|
||||||
|
simple_ack: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Frame metadata
///
/// Transport-level flags carried alongside the payload; `Default` yields
/// all flags clear and zero padding.
#[derive(Debug, Clone, Default)]
pub struct FrameMeta {
    /// Quick ACK requested - client wants immediate acknowledgment
    pub quickack: bool,
    /// This is a simple ACK message (reversed data)
    pub simple_ack: bool,
    /// Original padding length (for secure mode)
    pub padding_len: u8,
}
|
||||||
|
|
||||||
|
impl FrameMeta {
|
||||||
|
/// Create new empty metadata
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with quickack flag
|
||||||
|
pub fn with_quickack(mut self) -> Self {
|
||||||
|
self.quickack = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with simple_ack flag
|
||||||
|
pub fn with_simple_ack(mut self) -> Self {
|
||||||
|
self.simple_ack = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with padding length
|
||||||
|
pub fn with_padding(mut self, len: u8) -> Self {
|
||||||
|
self.padding_len = len;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if any special flags are set
|
||||||
|
pub fn has_flags(&self) -> bool {
|
||||||
|
self.quickack || self.simple_ack
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Codec Trait =============
|
||||||
|
|
||||||
|
/// Trait for frame codecs that can encode and decode frames
///
/// Implementations are stateless with respect to individual frames and
/// must be usable from multiple threads (`Send + Sync`).
pub trait FrameCodec: Send + Sync {
    /// Get the protocol tag for this codec
    fn proto_tag(&self) -> ProtoTag;

    /// Encode a frame into the destination buffer
    ///
    /// Returns the number of bytes written.
    fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;

    /// Try to decode a frame from the source buffer
    ///
    /// Returns:
    /// - `Ok(Some(frame))` if a complete frame was decoded
    /// - `Ok(None)` if more data is needed
    /// - `Err(e)` if an error occurred
    ///
    /// On success, the consumed bytes are removed from `src`.
    fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;

    /// Get the minimum bytes needed to determine frame length
    fn min_header_size(&self) -> usize;

    /// Get the maximum allowed frame size
    fn max_frame_size(&self) -> usize {
        // Default: 16MB
        16 * 1024 * 1024
    }
}
|
||||||
|
|
||||||
|
// ============= Codec Factory =============
|
||||||
|
|
||||||
|
/// Create a frame codec for the given protocol tag
///
/// Only the Secure variant consumes the RNG (for padding); for the other
/// variants `rng` is dropped unused.
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
    match proto_tag {
        ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
        ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
        ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors set length/emptiness and the quickack flag correctly.
    #[test]
    fn test_frame_creation() {
        let frame = Frame::new(Bytes::from_static(b"test"));
        assert_eq!(frame.len(), 4);
        assert!(!frame.is_empty());
        assert!(!frame.meta.quickack);

        let frame = Frame::empty();
        assert!(frame.is_empty());

        let frame = Frame::quickack(Bytes::from_static(b"ack"));
        assert!(frame.meta.quickack);
    }

    /// Builder methods compose and leave untouched flags clear.
    #[test]
    fn test_frame_meta() {
        let meta = FrameMeta::new()
            .with_quickack()
            .with_padding(3);

        assert!(meta.quickack);
        assert!(!meta.simple_ack);
        assert_eq!(meta.padding_len, 3);
        assert!(meta.has_flags());
    }
}
|
||||||
628
src/stream/frame_codec.rs
Normal file
628
src/stream/frame_codec.rs
Normal file
@@ -0,0 +1,628 @@
|
|||||||
|
//! tokio-util codec integration for MTProto frames
|
||||||
|
//!
|
||||||
|
//! This module provides Encoder/Decoder implementations compatible
|
||||||
|
//! with tokio-util's Framed wrapper for easy async frame I/O.
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut, BufMut};
|
||||||
|
use std::io::{self, Error, ErrorKind};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|
||||||
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
||||||
|
|
||||||
|
// ============= Unified Codec =============
|
||||||
|
|
||||||
|
/// Unified frame codec that wraps all protocol variants
///
/// This codec implements tokio-util's Encoder and Decoder traits,
/// allowing it to be used with `Framed` for async frame I/O.
pub struct FrameCodec {
    /// Protocol variant
    proto_tag: ProtoTag,
    /// Maximum allowed frame size
    max_frame_size: usize,
    /// RNG for secure padding
    // Only consulted by the Secure variant's encode path.
    rng: Arc<SecureRandom>,
}
|
||||||
|
|
||||||
|
impl FrameCodec {
|
||||||
|
/// Create a new codec for the given protocol
|
||||||
|
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||||
|
Self {
|
||||||
|
proto_tag,
|
||||||
|
max_frame_size: 16 * 1024 * 1024, // 16MB default
|
||||||
|
rng,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set maximum frame size
|
||||||
|
pub fn with_max_frame_size(mut self, size: usize) -> Self {
|
||||||
|
self.max_frame_size = size;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get protocol tag
|
||||||
|
pub fn proto_tag(&self) -> ProtoTag {
|
||||||
|
self.proto_tag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = io::Error;

    /// Dispatch to the per-protocol decoder; see those functions for the
    /// `Ok(None)` (need more data) semantics.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match self.proto_tag {
            ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
            ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size),
            ProtoTag::Secure => decode_secure(src, self.max_frame_size),
        }
    }
}

impl Encoder<Frame> for FrameCodec {
    type Error = io::Error;

    /// Dispatch to the per-protocol encoder; only Secure uses the RNG.
    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        match self.proto_tag {
            ProtoTag::Abridged => encode_abridged(&frame, dst),
            ProtoTag::Intermediate => encode_intermediate(&frame, dst),
            ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
        }
    }
}
|
||||||
|
|
||||||
|
// ============= Abridged Protocol =============
|
||||||
|
|
||||||
|
/// Decode one Abridged-protocol frame from `src`.
///
/// Header is 1 byte (length in 4-byte words, high bit = QuickACK request);
/// a length byte of 0x7f switches to a 3-byte little-endian extended length.
/// Returns `Ok(None)` when `src` does not yet hold a complete frame.
fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
    if src.is_empty() {
        return Ok(None);
    }

    let mut meta = FrameMeta::new();
    let first_byte = src[0];

    // Extract length and quickack flag
    let mut len_words = (first_byte & 0x7f) as usize;
    if first_byte >= 0x80 {
        // High bit set means the peer requests a quick acknowledgment.
        meta.quickack = true;
    }

    let header_len;

    if len_words == 0x7f {
        // Extended length (3 more bytes needed)
        if src.len() < 4 {
            return Ok(None);
        }
        // 24-bit little-endian word count follows the marker byte.
        len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize;
        header_len = 4;
    } else {
        header_len = 1;
    }

    // Length is in 4-byte words
    let byte_len = len_words.checked_mul(4).ok_or_else(|| {
        Error::new(ErrorKind::InvalidData, "frame length overflow")
    })?;

    // Validate size before waiting for the body, so oversized frames fail fast.
    if byte_len > max_size {
        return Err(Error::new(
            ErrorKind::InvalidData,
            format!("frame too large: {} bytes (max {})", byte_len, max_size)
        ));
    }

    let total_len = header_len + byte_len;

    if src.len() < total_len {
        // Reserve space for the rest of the frame
        src.reserve(total_len - src.len());
        return Ok(None);
    }

    // Extract data: discard the header, then hand out the payload zero-copy.
    let _ = src.split_to(header_len);
    let data = src.split_to(byte_len).freeze();

    Ok(Some(Frame::with_meta(data, meta)))
}
|
||||||
|
|
||||||
|
/// Encode `frame` with Abridged framing into `dst`.
///
/// Payload must be 4-byte aligned (lengths are sent in words). Simple ACKs
/// are sent as the reversed payload with no header.
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
    let data = &frame.data;

    // Validate alignment
    if data.len() % 4 != 0 {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
        ));
    }

    // Simple ACK: send reversed data without header
    if frame.meta.simple_ack {
        dst.reserve(data.len());
        for byte in data.iter().rev() {
            dst.put_u8(*byte);
        }
        return Ok(());
    }

    let len_words = data.len() / 4;

    if len_words < 0x7f {
        // Short header: single byte, high bit carries the QuickACK flag.
        dst.reserve(1 + data.len());
        let mut len_byte = len_words as u8;
        if frame.meta.quickack {
            len_byte |= 0x80;
        }
        dst.put_u8(len_byte);
    } else if len_words < (1 << 24) {
        // Extended header: 0x7f marker (plus flag bit) + 24-bit LE length.
        dst.reserve(4 + data.len());
        let mut first = 0x7fu8;
        if frame.meta.quickack {
            first |= 0x80;
        }
        dst.put_u8(first);
        let len_bytes = (len_words as u32).to_le_bytes();
        dst.extend_from_slice(&len_bytes[..3]);
    } else {
        // Word count does not fit in 24 bits.
        return Err(Error::new(
            ErrorKind::InvalidInput,
            format!("frame too large: {} bytes", data.len())
        ));
    }

    dst.extend_from_slice(data);
    Ok(())
}
|
||||||
|
|
||||||
|
// ============= Intermediate Protocol =============
|
||||||
|
|
||||||
|
/// Decode one Intermediate-protocol frame from `src`.
///
/// Header is a 4-byte little-endian byte length; the top bit carries the
/// QuickACK request flag. Returns `Ok(None)` until a full frame is buffered.
fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
    if src.len() < 4 {
        return Ok(None);
    }

    let mut meta = FrameMeta::new();
    let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;

    // Check QuickACK flag
    if len >= 0x80000000 {
        meta.quickack = true;
        len -= 0x80000000;
    }

    // Validate size before waiting for the body.
    if len > max_size {
        return Err(Error::new(
            ErrorKind::InvalidData,
            format!("frame too large: {} bytes (max {})", len, max_size)
        ));
    }

    let total_len = 4 + len;

    if src.len() < total_len {
        // Hint the transport how much more data we expect.
        src.reserve(total_len - src.len());
        return Ok(None);
    }

    // Extract data: discard the header, hand out the payload zero-copy.
    let _ = src.split_to(4);
    let data = src.split_to(len).freeze();

    Ok(Some(Frame::with_meta(data, meta)))
}
|
||||||
|
|
||||||
|
/// Encode `frame` with Intermediate framing into `dst`: a 4-byte LE byte
/// length (top bit = QuickACK) followed by the raw payload. Simple ACKs
/// are written raw, without a length prefix.
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
    let payload = &frame.data;

    // Simple ACK: the payload is sent as-is, headerless.
    if frame.meta.simple_ack {
        dst.reserve(payload.len());
        dst.extend_from_slice(payload);
        return Ok(());
    }

    // Length word with the QuickACK flag folded into the top bit.
    let header = if frame.meta.quickack {
        payload.len() as u32 | 0x80000000
    } else {
        payload.len() as u32
    };

    dst.reserve(4 + payload.len());
    dst.put_u32_le(header);
    dst.extend_from_slice(payload);

    Ok(())
}
|
||||||
|
|
||||||
|
// ============= Secure Intermediate Protocol =============
|
||||||
|
|
||||||
|
fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame>> {
|
||||||
|
if src.len() < 4 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut meta = FrameMeta::new();
|
||||||
|
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||||
|
|
||||||
|
// Check QuickACK flag
|
||||||
|
if len >= 0x80000000 {
|
||||||
|
meta.quickack = true;
|
||||||
|
len -= 0x80000000;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate size
|
||||||
|
if len > max_size {
|
||||||
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidData,
|
||||||
|
format!("frame too large: {} bytes (max {})", len, max_size)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_len = 4 + len;
|
||||||
|
|
||||||
|
if src.len() < total_len {
|
||||||
|
src.reserve(total_len - src.len());
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate padding (indicated by length not divisible by 4)
|
||||||
|
let padding_len = len % 4;
|
||||||
|
let data_len = if padding_len != 0 {
|
||||||
|
len - padding_len
|
||||||
|
} else {
|
||||||
|
len
|
||||||
|
};
|
||||||
|
|
||||||
|
meta.padding_len = padding_len as u8;
|
||||||
|
|
||||||
|
// Extract data (excluding padding)
|
||||||
|
let _ = src.split_to(4);
|
||||||
|
let all_data = src.split_to(len);
|
||||||
|
// Copy only the data portion, excluding padding
|
||||||
|
let data = Bytes::copy_from_slice(&all_data[..data_len]);
|
||||||
|
|
||||||
|
Ok(Some(Frame::with_meta(data, meta)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode `frame` with Secure-Intermediate framing into `dst`: 4-byte LE
/// length (top bit = QuickACK) over payload + random padding.
///
/// Padding keeps the total body length NOT divisible by 4 so the decoder
/// can recover the padding length from `len % 4`.
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
    let data = &frame.data;

    // Simple ACK: just send data
    if frame.meta.simple_ack {
        dst.reserve(data.len());
        dst.extend_from_slice(data);
        return Ok(());
    }

    // Generate padding to make length not divisible by 4
    // assumes rng.range(n) yields a value in [0, n) — TODO confirm against
    // the SecureRandom API.
    // NOTE(review): if `data` is already non-aligned, a decoder inferring
    // padding from `len % 4` cannot reconstruct the payload exactly —
    // confirm callers only pass 4-byte-aligned payloads here.
    let padding_len = if data.len() % 4 == 0 {
        // Add 1-3 bytes to make it non-aligned
        (rng.range(3) + 1) as usize
    } else {
        // Already non-aligned, can add 0-3
        rng.range(4) as usize
    };

    let total_len = data.len() + padding_len;
    dst.reserve(4 + total_len);

    let mut len = total_len as u32;
    if frame.meta.quickack {
        len |= 0x80000000;
    }

    dst.extend_from_slice(&len.to_le_bytes());
    dst.extend_from_slice(data);

    if padding_len > 0 {
        // Random bytes so padding is indistinguishable from payload.
        let padding = rng.bytes(padding_len);
        dst.extend_from_slice(&padding);
    }

    Ok(())
}
|
||||||
|
|
||||||
|
// ============= Typed Codecs =============
|
||||||
|
|
||||||
|
/// Abridged protocol codec
///
/// Thin wrapper over `decode_abridged`/`encode_abridged` with a fixed
/// 16 MiB frame cap; implements both the tokio-util codec traits and the
/// object-safe `FrameCodec` trait.
pub struct AbridgedCodec {
    /// Maximum accepted/produced frame size in bytes.
    max_frame_size: usize,
}

impl AbridgedCodec {
    /// Codec with the default 16 MiB frame size limit.
    pub fn new() -> Self {
        Self {
            max_frame_size: 16 * 1024 * 1024,
        }
    }
}

impl Default for AbridgedCodec {
    /// Equivalent to [`AbridgedCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Decoder for AbridgedCodec {
    type Item = Frame;
    type Error = io::Error;

    /// tokio-util entry point; delegates to the free decode function.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        decode_abridged(src, self.max_frame_size)
    }
}

impl Encoder<Frame> for AbridgedCodec {
    type Error = io::Error;

    /// tokio-util entry point; delegates to the free encode function.
    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        encode_abridged(&frame, dst)
    }
}

impl FrameCodecTrait for AbridgedCodec {
    fn proto_tag(&self) -> ProtoTag {
        ProtoTag::Abridged
    }

    /// Encode and report the number of bytes appended to `dst`.
    fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
        let before = dst.len();
        encode_abridged(frame, dst)?;
        Ok(dst.len() - before)
    }

    fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
        decode_abridged(src, self.max_frame_size)
    }

    /// One byte suffices to read the short-form length.
    fn min_header_size(&self) -> usize {
        1
    }
}
|
||||||
|
|
||||||
|
/// Intermediate protocol codec
|
||||||
|
pub struct IntermediateCodec {
|
||||||
|
max_frame_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntermediateCodec {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
max_frame_size: 16 * 1024 * 1024,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IntermediateCodec {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder for IntermediateCodec {
|
||||||
|
type Item = Frame;
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||||
|
decode_intermediate(src, self.max_frame_size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder<Frame> for IntermediateCodec {
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||||
|
encode_intermediate(&frame, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FrameCodecTrait for IntermediateCodec {
|
||||||
|
fn proto_tag(&self) -> ProtoTag {
|
||||||
|
ProtoTag::Intermediate
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||||
|
let before = dst.len();
|
||||||
|
encode_intermediate(frame, dst)?;
|
||||||
|
Ok(dst.len() - before)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||||
|
decode_intermediate(src, self.max_frame_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn min_header_size(&self) -> usize {
|
||||||
|
4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Secure Intermediate protocol codec
|
||||||
|
pub struct SecureCodec {
|
||||||
|
max_frame_size: usize,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SecureCodec {
|
||||||
|
pub fn new(rng: Arc<SecureRandom>) -> Self {
|
||||||
|
Self {
|
||||||
|
max_frame_size: 16 * 1024 * 1024,
|
||||||
|
rng,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SecureCodec {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(Arc::new(SecureRandom::new()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder for SecureCodec {
|
||||||
|
type Item = Frame;
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||||
|
decode_secure(src, self.max_frame_size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder<Frame> for SecureCodec {
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||||
|
encode_secure(&frame, dst, &self.rng)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FrameCodecTrait for SecureCodec {
|
||||||
|
fn proto_tag(&self) -> ProtoTag {
|
||||||
|
ProtoTag::Secure
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||||
|
let before = dst.len();
|
||||||
|
encode_secure(frame, dst, &self.rng)?;
|
||||||
|
Ok(dst.len() - before)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||||
|
decode_secure(src, self.max_frame_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn min_header_size(&self) -> usize {
|
||||||
|
4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Tests =============
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||||
|
use tokio::io::duplex;
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_framed_abridged() {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
|
||||||
|
let mut reader = FramedRead::new(server, AbridgedCodec::new());
|
||||||
|
|
||||||
|
// Write a frame
|
||||||
|
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
|
||||||
|
// Read it back
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_framed_intermediate() {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||||
|
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||||
|
|
||||||
|
let frame = Frame::new(Bytes::from_static(b"hello world"));
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(&received.data[..], b"hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_framed_secure() {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||||
|
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||||
|
|
||||||
|
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
let frame = Frame::new(original.clone());
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(&received.data[..], &original[..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unified_codec() {
|
||||||
|
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||||
|
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||||
|
|
||||||
|
// Use 4-byte aligned data for abridged compatibility
|
||||||
|
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
let frame = Frame::new(original.clone());
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(received.data.len(), 8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multiple_frames() {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||||
|
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||||
|
|
||||||
|
// Send multiple frames
|
||||||
|
for i in 0..10 {
|
||||||
|
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
|
||||||
|
let frame = Frame::new(Bytes::from(data));
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive them
|
||||||
|
for i in 0..10 {
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(received.data.len(), (i + 1) * 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_quickack_flag() {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
|
||||||
|
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||||
|
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||||
|
|
||||||
|
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
|
||||||
|
writer.send(frame).await.unwrap();
|
||||||
|
|
||||||
|
let received = reader.next().await.unwrap().unwrap();
|
||||||
|
assert!(received.meta.quickack);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_frame_too_large() {
|
||||||
|
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
|
||||||
|
.with_max_frame_size(100);
|
||||||
|
|
||||||
|
// Create a "frame" that claims to be very large
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
|
||||||
|
buf.extend_from_slice(&[0u8; 10]); // partial data
|
||||||
|
|
||||||
|
let result = codec.decode(&mut buf);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
|
|||||||
use std::io::{Error, ErrorKind, Result};
|
use std::io::{Error, ErrorKind, Result};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::crypto::crc32;
|
use crate::crypto::{crc32, SecureRandom};
|
||||||
use crate::crypto::random::SECURE_RANDOM;
|
use std::sync::Arc;
|
||||||
use super::traits::{FrameMeta, LayeredStream};
|
use super::traits::{FrameMeta, LayeredStream};
|
||||||
|
|
||||||
// ============= Abridged (Compact) Frame =============
|
// ============= Abridged (Compact) Frame =============
|
||||||
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
|||||||
/// Writer for secure intermediate MTProto framing
|
/// Writer for secure intermediate MTProto framing
|
||||||
pub struct SecureIntermediateFrameWriter<W> {
|
pub struct SecureIntermediateFrameWriter<W> {
|
||||||
upstream: W,
|
upstream: W,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W> SecureIntermediateFrameWriter<W> {
|
impl<W> SecureIntermediateFrameWriter<W> {
|
||||||
pub fn new(upstream: W) -> Self {
|
pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
|
||||||
Self { upstream }
|
Self { upstream, rng }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add random padding (0-3 bytes)
|
// Add random padding (0-3 bytes)
|
||||||
let padding_len = SECURE_RANDOM.range(4);
|
let padding_len = self.rng.range(4);
|
||||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
let padding = self.rng.bytes(padding_len);
|
||||||
|
|
||||||
let total_len = data.len() + padding_len;
|
let total_len = data.len() + padding_len;
|
||||||
let len_bytes = (total_len as u32).to_le_bytes();
|
let len_bytes = (total_len as u32).to_le_bytes();
|
||||||
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||||
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
|
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||||
match proto_tag {
|
match proto_tag {
|
||||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
|
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_abridged_roundtrip() {
|
async fn test_abridged_roundtrip() {
|
||||||
@@ -539,7 +542,7 @@ mod tests {
|
|||||||
async fn test_secure_intermediate_padding() {
|
async fn test_secure_intermediate_padding() {
|
||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
|
|
||||||
let mut writer = SecureIntermediateFrameWriter::new(client);
|
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
|
||||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||||
|
|
||||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||||
@@ -572,7 +575,7 @@ mod tests {
|
|||||||
async fn test_frame_reader_kind() {
|
async fn test_frame_reader_kind() {
|
||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
|
|
||||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
|
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
|
||||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||||
|
|
||||||
let data = vec![1u8, 2, 3, 4];
|
let data = vec![1u8, 2, 3, 4];
|
||||||
|
|||||||
@@ -1,10 +1,43 @@
|
|||||||
//! Stream wrappers for MTProto protocol layers
|
//! Stream wrappers for MTProto protocol layers
|
||||||
|
|
||||||
|
pub mod state;
|
||||||
|
pub mod buffer_pool;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
pub mod crypto_stream;
|
pub mod crypto_stream;
|
||||||
pub mod tls_stream;
|
pub mod tls_stream;
|
||||||
|
pub mod frame;
|
||||||
|
pub mod frame_codec;
|
||||||
|
|
||||||
|
// Legacy compatibility - will be removed later
|
||||||
pub mod frame_stream;
|
pub mod frame_stream;
|
||||||
|
|
||||||
|
// Re-export state machine types
|
||||||
|
pub use state::{
|
||||||
|
StreamState, Transition, PollResult,
|
||||||
|
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Re-export buffer pool
|
||||||
|
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
|
||||||
|
|
||||||
|
// Re-export stream implementations
|
||||||
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
||||||
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||||
pub use frame_stream::*;
|
|
||||||
|
// Re-export frame types
|
||||||
|
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
|
||||||
|
|
||||||
|
// Re-export tokio-util compatible codecs
|
||||||
|
pub use frame_codec::{
|
||||||
|
FrameCodec,
|
||||||
|
AbridgedCodec, IntermediateCodec, SecureCodec,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Legacy re-exports for compatibility
|
||||||
|
pub use frame_stream::{
|
||||||
|
AbridgedFrameReader, AbridgedFrameWriter,
|
||||||
|
IntermediateFrameReader, IntermediateFrameWriter,
|
||||||
|
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
|
||||||
|
MtprotoFrameReader, MtprotoFrameWriter,
|
||||||
|
FrameReaderKind, FrameWriterKind,
|
||||||
|
};
|
||||||
571
src/stream/state.rs
Normal file
571
src/stream/state.rs
Normal file
@@ -0,0 +1,571 @@
|
|||||||
|
//! State machine foundation types for async streams
|
||||||
|
//!
|
||||||
|
//! This module provides core types and traits for implementing
|
||||||
|
//! stateful async streams with proper partial read/write handling.
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
// ============= Core Traits =============
|
||||||
|
|
||||||
|
/// Trait for stream states
|
||||||
|
pub trait StreamState: Sized {
|
||||||
|
/// Check if this is a terminal state (no more transitions possible)
|
||||||
|
fn is_terminal(&self) -> bool;
|
||||||
|
|
||||||
|
/// Check if stream is in poisoned/error state
|
||||||
|
fn is_poisoned(&self) -> bool;
|
||||||
|
|
||||||
|
/// Get human-readable state name for debugging
|
||||||
|
fn state_name(&self) -> &'static str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Transition Types =============
|
||||||
|
|
||||||
|
/// Result of a state transition
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Transition<S, O> {
|
||||||
|
/// Stay in the same state, no output
|
||||||
|
Same,
|
||||||
|
/// Transition to a new state, no output
|
||||||
|
Next(S),
|
||||||
|
/// Complete with output, typically transitions to Idle
|
||||||
|
Complete(O),
|
||||||
|
/// Yield output and transition to new state
|
||||||
|
Yield(O, S),
|
||||||
|
/// Error occurred, transition to error state
|
||||||
|
Error(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, O> Transition<S, O> {
|
||||||
|
/// Check if transition produces output
|
||||||
|
pub fn has_output(&self) -> bool {
|
||||||
|
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the output value
|
||||||
|
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
|
||||||
|
match self {
|
||||||
|
Transition::Same => Transition::Same,
|
||||||
|
Transition::Next(s) => Transition::Next(s),
|
||||||
|
Transition::Complete(o) => Transition::Complete(f(o)),
|
||||||
|
Transition::Yield(o, s) => Transition::Yield(f(o), s),
|
||||||
|
Transition::Error(e) => Transition::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the state value
|
||||||
|
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
|
||||||
|
match self {
|
||||||
|
Transition::Same => Transition::Same,
|
||||||
|
Transition::Next(s) => Transition::Next(f(s)),
|
||||||
|
Transition::Complete(o) => Transition::Complete(o),
|
||||||
|
Transition::Yield(o, s) => Transition::Yield(o, f(s)),
|
||||||
|
Transition::Error(e) => Transition::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Poll Result Types =============
|
||||||
|
|
||||||
|
/// Result of polling for more data
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum PollResult<T> {
|
||||||
|
/// Data is ready
|
||||||
|
Ready(T),
|
||||||
|
/// Operation would block, need to poll again
|
||||||
|
Pending,
|
||||||
|
/// Need more input data (minimum bytes required)
|
||||||
|
NeedInput(usize),
|
||||||
|
/// End of stream reached
|
||||||
|
Eof,
|
||||||
|
/// Error occurred
|
||||||
|
Error(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> PollResult<T> {
|
||||||
|
/// Check if result is ready
|
||||||
|
pub fn is_ready(&self) -> bool {
|
||||||
|
matches!(self, PollResult::Ready(_))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if result indicates EOF
|
||||||
|
pub fn is_eof(&self) -> bool {
|
||||||
|
matches!(self, PollResult::Eof)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert to Option, discarding non-ready states
|
||||||
|
pub fn ok(self) -> Option<T> {
|
||||||
|
match self {
|
||||||
|
PollResult::Ready(t) => Some(t),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map the value
|
||||||
|
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
|
||||||
|
match self {
|
||||||
|
PollResult::Ready(t) => PollResult::Ready(f(t)),
|
||||||
|
PollResult::Pending => PollResult::Pending,
|
||||||
|
PollResult::NeedInput(n) => PollResult::NeedInput(n),
|
||||||
|
PollResult::Eof => PollResult::Eof,
|
||||||
|
PollResult::Error(e) => PollResult::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<io::Result<T>> for PollResult<T> {
|
||||||
|
fn from(result: io::Result<T>) -> Self {
|
||||||
|
match result {
|
||||||
|
Ok(t) => PollResult::Ready(t),
|
||||||
|
Err(e) if e.kind() == io::ErrorKind::WouldBlock => PollResult::Pending,
|
||||||
|
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => PollResult::Eof,
|
||||||
|
Err(e) => PollResult::Error(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Buffer State =============
|
||||||
|
|
||||||
|
/// State for buffered reading operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ReadBuffer {
|
||||||
|
/// The buffer holding data
|
||||||
|
buffer: BytesMut,
|
||||||
|
/// Target number of bytes to read (if known)
|
||||||
|
target: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReadBuffer {
|
||||||
|
/// Create new empty read buffer
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(8192),
|
||||||
|
target: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with specific capacity
|
||||||
|
pub fn with_capacity(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(capacity),
|
||||||
|
target: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with target size
|
||||||
|
pub fn with_target(target: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(target),
|
||||||
|
target: Some(target),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current buffer length
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.buffer.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if target is reached
|
||||||
|
pub fn is_complete(&self) -> bool {
|
||||||
|
match self.target {
|
||||||
|
Some(t) => self.buffer.len() >= t,
|
||||||
|
None => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes needed
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
match self.target {
|
||||||
|
Some(t) => t.saturating_sub(self.buffer.len()),
|
||||||
|
None => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append data to buffer
|
||||||
|
pub fn extend(&mut self, data: &[u8]) {
|
||||||
|
self.buffer.extend_from_slice(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take all data from buffer
|
||||||
|
pub fn take(&mut self) -> Bytes {
|
||||||
|
self.target = None;
|
||||||
|
self.buffer.split().freeze()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take exactly n bytes
|
||||||
|
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
|
||||||
|
if self.buffer.len() >= n {
|
||||||
|
Some(self.buffer.split_to(n).freeze())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a slice of the buffer
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get mutable access to underlying BytesMut
|
||||||
|
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
|
||||||
|
&mut self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.target = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set new target
|
||||||
|
pub fn set_target(&mut self, target: usize) {
|
||||||
|
self.target = Some(target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ReadBuffer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State for buffered writing operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct WriteBuffer {
|
||||||
|
/// The buffer holding data to write
|
||||||
|
buffer: BytesMut,
|
||||||
|
/// Position of next byte to write
|
||||||
|
position: usize,
|
||||||
|
/// Maximum buffer size
|
||||||
|
max_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WriteBuffer {
|
||||||
|
/// Create new write buffer with default max size (256KB)
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::with_max_size(256 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create with specific max size
|
||||||
|
pub fn with_max_size(max_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: BytesMut::with_capacity(8192),
|
||||||
|
position: 0,
|
||||||
|
max_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get pending bytes count
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len() - self.position
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is empty (all written)
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.position >= self.buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer is full
|
||||||
|
pub fn is_full(&self) -> bool {
|
||||||
|
self.buffer.len() >= self.max_size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining capacity
|
||||||
|
pub fn remaining_capacity(&self) -> usize {
|
||||||
|
self.max_size.saturating_sub(self.buffer.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append data to buffer
|
||||||
|
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
|
||||||
|
if self.buffer.len() + data.len() > self.max_size {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
self.buffer.extend_from_slice(data);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get slice of data to write
|
||||||
|
pub fn pending(&self) -> &[u8] {
|
||||||
|
&self.buffer[self.position..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advance position by n bytes (after successful write)
|
||||||
|
pub fn advance(&mut self, n: usize) {
|
||||||
|
self.position += n;
|
||||||
|
|
||||||
|
// If all data written, reset buffer
|
||||||
|
if self.position >= self.buffer.len() {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.position = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the buffer
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.buffer.clear();
|
||||||
|
self.position = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WriteBuffer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Fixed-Size Buffer States =============
|
||||||
|
|
||||||
|
/// State for reading a fixed-size header
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HeaderBuffer<const N: usize> {
|
||||||
|
/// The buffer
|
||||||
|
data: [u8; N],
|
||||||
|
/// Bytes filled so far
|
||||||
|
filled: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> HeaderBuffer<N> {
|
||||||
|
/// Create new empty header buffer
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
data: [0u8; N],
|
||||||
|
filled: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get slice for reading into
|
||||||
|
pub fn unfilled_mut(&mut self) -> &mut [u8] {
|
||||||
|
&mut self.data[self.filled..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advance filled count
|
||||||
|
pub fn advance(&mut self, n: usize) {
|
||||||
|
self.filled = (self.filled + n).min(N);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if completely filled
|
||||||
|
pub fn is_complete(&self) -> bool {
|
||||||
|
self.filled >= N
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes needed
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
N - self.filled
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get filled bytes as slice
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.data[..self.filled]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get complete buffer (panics if not complete)
|
||||||
|
pub fn as_array(&self) -> &[u8; N] {
|
||||||
|
assert!(self.is_complete());
|
||||||
|
&self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take the buffer, resetting state
|
||||||
|
pub fn take(&mut self) -> [u8; N] {
|
||||||
|
let data = self.data;
|
||||||
|
self.data = [0u8; N];
|
||||||
|
self.filled = 0;
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset to empty state
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.filled = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const N: usize> Default for HeaderBuffer<N> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Yield Buffer =============
|
||||||
|
|
||||||
|
/// Buffer for yielding data to caller in chunks
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct YieldBuffer {
|
||||||
|
data: Bytes,
|
||||||
|
position: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YieldBuffer {
|
||||||
|
/// Create new yield buffer
|
||||||
|
pub fn new(data: Bytes) -> Self {
|
||||||
|
Self { data, position: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if all data has been yielded
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.position >= self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining bytes
|
||||||
|
pub fn remaining(&self) -> usize {
|
||||||
|
self.data.len() - self.position
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy data to output slice, return bytes copied
|
||||||
|
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
|
||||||
|
let available = &self.data[self.position..];
|
||||||
|
let to_copy = available.len().min(dst.len());
|
||||||
|
dst[..to_copy].copy_from_slice(&available[..to_copy]);
|
||||||
|
self.position += to_copy;
|
||||||
|
to_copy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get remaining data as slice
|
||||||
|
pub fn as_slice(&self) -> &[u8] {
|
||||||
|
&self.data[self.position..]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Macros =============
|
||||||
|
|
||||||
|
/// Macro to simplify state transitions in poll methods
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! transition {
|
||||||
|
(same) => {
|
||||||
|
$crate::stream::state::Transition::Same
|
||||||
|
};
|
||||||
|
(next $state:expr) => {
|
||||||
|
$crate::stream::state::Transition::Next($state)
|
||||||
|
};
|
||||||
|
(complete $output:expr) => {
|
||||||
|
$crate::stream::state::Transition::Complete($output)
|
||||||
|
};
|
||||||
|
(yield $output:expr, $state:expr) => {
|
||||||
|
$crate::stream::state::Transition::Yield($output, $state)
|
||||||
|
};
|
||||||
|
(error $err:expr) => {
|
||||||
|
$crate::stream::state::Transition::Error($err)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Macro to match poll ready or return pending
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! ready_or_pending {
|
||||||
|
($poll:expr) => {
|
||||||
|
match $poll {
|
||||||
|
std::task::Poll::Ready(t) => t,
|
||||||
|
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_read_buffer_basic() {
|
||||||
|
let mut buf = ReadBuffer::with_target(10);
|
||||||
|
assert_eq!(buf.remaining(), 10);
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
|
||||||
|
buf.extend(b"hello");
|
||||||
|
assert_eq!(buf.len(), 5);
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
|
||||||
|
buf.extend(b"world");
|
||||||
|
assert_eq!(buf.len(), 10);
|
||||||
|
assert!(buf.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_read_buffer_take() {
|
||||||
|
let mut buf = ReadBuffer::new();
|
||||||
|
buf.extend(b"test data");
|
||||||
|
|
||||||
|
let data = buf.take();
|
||||||
|
assert_eq!(&data[..], b"test data");
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_buffer_basic() {
|
||||||
|
let mut buf = WriteBuffer::with_max_size(100);
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
|
||||||
|
buf.extend(b"hello").unwrap();
|
||||||
|
assert_eq!(buf.len(), 5);
|
||||||
|
assert!(!buf.is_empty());
|
||||||
|
|
||||||
|
buf.advance(3);
|
||||||
|
assert_eq!(buf.len(), 2);
|
||||||
|
assert_eq!(buf.pending(), b"lo");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_buffer_overflow() {
|
||||||
|
let mut buf = WriteBuffer::with_max_size(10);
|
||||||
|
assert!(buf.extend(b"short").is_ok());
|
||||||
|
assert!(buf.extend(b"toolong").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_header_buffer() {
|
||||||
|
let mut buf = HeaderBuffer::<5>::new();
|
||||||
|
assert!(!buf.is_complete());
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
|
||||||
|
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
|
||||||
|
buf.advance(3);
|
||||||
|
assert_eq!(buf.remaining(), 2);
|
||||||
|
|
||||||
|
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
|
||||||
|
buf.advance(2);
|
||||||
|
assert!(buf.is_complete());
|
||||||
|
assert_eq!(buf.as_array(), b"hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_yield_buffer() {
|
||||||
|
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
|
||||||
|
|
||||||
|
let mut dst = [0u8; 5];
|
||||||
|
assert_eq!(buf.copy_to(&mut dst), 5);
|
||||||
|
assert_eq!(&dst, b"hello");
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 6);
|
||||||
|
|
||||||
|
let mut dst = [0u8; 10];
|
||||||
|
assert_eq!(buf.copy_to(&mut dst), 6);
|
||||||
|
assert_eq!(&dst[..6], b" world");
|
||||||
|
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transition_map() {
|
||||||
|
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
|
||||||
|
let t = t.map_output(|s| s.len());
|
||||||
|
|
||||||
|
match t {
|
||||||
|
Transition::Complete(5) => {}
|
||||||
|
_ => panic!("Expected Complete(5)"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_poll_result() {
|
||||||
|
let r: PollResult<i32> = PollResult::Ready(42);
|
||||||
|
assert!(r.is_ready());
|
||||||
|
assert_eq!(r.ok(), Some(42));
|
||||||
|
|
||||||
|
let r: PollResult<i32> = PollResult::Eof;
|
||||||
|
assert!(r.is_eof());
|
||||||
|
assert_eq!(r.ok(), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,11 @@
|
|||||||
pub mod pool;
|
pub mod pool;
|
||||||
pub mod proxy_protocol;
|
pub mod proxy_protocol;
|
||||||
pub mod socket;
|
pub mod socket;
|
||||||
|
pub mod socks;
|
||||||
|
pub mod upstream;
|
||||||
|
|
||||||
pub use pool::ConnectionPool;
|
pub use pool::ConnectionPool;
|
||||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||||
pub use socket::*;
|
pub use socket::*;
|
||||||
|
pub use socks::*;
|
||||||
|
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
//! TCP Socket Configuration
|
//! TCP Socket Configuration
|
||||||
|
|
||||||
use std::io::Result;
|
use std::io::Result;
|
||||||
use std::net::SocketAddr;
|
use std::net::{SocketAddr, IpAddr};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
|
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
|
||||||
@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
|
|||||||
socket.set_tcp_keepalive(&keepalive)?;
|
socket.set_tcp_keepalive(&keepalive)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set buffer sizes
|
// CHANGED: Removed manual buffer size setting (was 256KB).
|
||||||
set_buffer_sizes(&socket, 65536, 65536)?;
|
// Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
|
||||||
|
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set socket buffer sizes
|
|
||||||
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
|
|
||||||
// These may fail on some systems, so we ignore errors
|
|
||||||
let _ = socket.set_recv_buffer_size(recv);
|
|
||||||
let _ = socket.set_send_buffer_size(send);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configure socket for accepting client connections
|
/// Configure socket for accepting client connections
|
||||||
pub fn configure_client_socket(
|
pub fn configure_client_socket(
|
||||||
stream: &TcpStream,
|
stream: &TcpStream,
|
||||||
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
|
|||||||
socket.set_tcp_keepalive(&keepalive)?;
|
socket.set_tcp_keepalive(&keepalive)?;
|
||||||
|
|
||||||
// Set TCP user timeout (Linux only)
|
// Set TCP user timeout (Linux only)
|
||||||
|
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
|
||||||
|
// is implemented in relay_bidirectional instead
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
{
|
{
|
||||||
use std::os::unix::io::AsRawFd;
|
use std::os::unix::io::AsRawFd;
|
||||||
@@ -93,6 +88,11 @@ pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
|
|||||||
|
|
||||||
/// Create a new TCP socket for outgoing connections
|
/// Create a new TCP socket for outgoing connections
|
||||||
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||||
|
create_outgoing_socket_bound(addr, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new TCP socket for outgoing connections, optionally bound to a specific interface
|
||||||
|
pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option<IpAddr>) -> Result<Socket> {
|
||||||
let domain = if addr.is_ipv4() {
|
let domain = if addr.is_ipv4() {
|
||||||
Domain::IPV4
|
Domain::IPV4
|
||||||
} else {
|
} else {
|
||||||
@@ -107,9 +107,16 @@ pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
|||||||
// Disable Nagle
|
// Disable Nagle
|
||||||
socket.set_nodelay(true)?;
|
socket.set_nodelay(true)?;
|
||||||
|
|
||||||
|
if let Some(bind_ip) = bind_addr {
|
||||||
|
let bind_sock_addr = SocketAddr::new(bind_ip, 0);
|
||||||
|
socket.bind(&bind_sock_addr.into())?;
|
||||||
|
debug!("Bound outgoing socket to {}", bind_ip);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(socket)
|
Ok(socket)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Get local address of a socket
|
/// Get local address of a socket
|
||||||
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||||
stream.local_addr().ok()
|
stream.local_addr().ok()
|
||||||
|
|||||||
145
src/transport/socks.rs
Normal file
145
src/transport/socks.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
//! SOCKS4/5 Client Implementation
|
||||||
|
|
||||||
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use crate::error::{ProxyError, Result};
|
||||||
|
|
||||||
|
pub async fn connect_socks4(
|
||||||
|
stream: &mut TcpStream,
|
||||||
|
target: SocketAddr,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ip = match target.ip() {
|
||||||
|
IpAddr::V4(ip) => ip,
|
||||||
|
IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
|
let port = target.port();
|
||||||
|
let user = user_id.unwrap_or("").as_bytes();
|
||||||
|
|
||||||
|
// VN (4) | CD (1) | DSTPORT (2) | DSTIP (4) | USERID (variable) | NULL (1)
|
||||||
|
let mut buf = Vec::with_capacity(9 + user.len());
|
||||||
|
buf.push(4); // VN
|
||||||
|
buf.push(1); // CD (CONNECT)
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf.extend_from_slice(&ip.octets());
|
||||||
|
buf.extend_from_slice(user);
|
||||||
|
buf.push(0); // NULL
|
||||||
|
|
||||||
|
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
// Response: VN (1) | CD (1) | DSTPORT (2) | DSTIP (4)
|
||||||
|
let mut resp = [0u8; 8];
|
||||||
|
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
if resp[1] != 90 {
|
||||||
|
return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1])));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_socks5(
|
||||||
|
stream: &mut TcpStream,
|
||||||
|
target: SocketAddr,
|
||||||
|
username: Option<&str>,
|
||||||
|
password: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
// 1. Auth negotiation
|
||||||
|
// VER (1) | NMETHODS (1) | METHODS (variable)
|
||||||
|
let mut methods = vec![0u8]; // No auth
|
||||||
|
if username.is_some() {
|
||||||
|
methods.push(2u8); // Username/Password
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut buf = vec![5u8, methods.len() as u8];
|
||||||
|
buf.extend_from_slice(&methods);
|
||||||
|
|
||||||
|
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
if resp[0] != 5 {
|
||||||
|
return Err(ProxyError::Proxy("Invalid SOCKS5 version".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match resp[1] {
|
||||||
|
0 => {}, // No auth
|
||||||
|
2 => {
|
||||||
|
// Username/Password auth
|
||||||
|
if let (Some(u), Some(p)) = (username, password) {
|
||||||
|
let u_bytes = u.as_bytes();
|
||||||
|
let p_bytes = p.as_bytes();
|
||||||
|
|
||||||
|
let mut auth_buf = Vec::with_capacity(3 + u_bytes.len() + p_bytes.len());
|
||||||
|
auth_buf.push(1); // VER
|
||||||
|
auth_buf.push(u_bytes.len() as u8);
|
||||||
|
auth_buf.extend_from_slice(u_bytes);
|
||||||
|
auth_buf.push(p_bytes.len() as u8);
|
||||||
|
auth_buf.extend_from_slice(p_bytes);
|
||||||
|
|
||||||
|
stream.write_all(&auth_buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
let mut auth_resp = [0u8; 2];
|
||||||
|
stream.read_exact(&mut auth_resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
if auth_resp[1] != 0 {
|
||||||
|
return Err(ProxyError::Proxy("SOCKS5 authentication failed".to_string()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(ProxyError::Proxy("SOCKS5 server requires authentication".to_string()));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => return Err(ProxyError::Proxy("Unsupported SOCKS5 auth method".to_string())),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Connection request
|
||||||
|
// VER (1) | CMD (1) | RSV (1) | ATYP (1) | DST.ADDR (variable) | DST.PORT (2)
|
||||||
|
let mut req = vec![5u8, 1u8, 0u8]; // CONNECT
|
||||||
|
|
||||||
|
match target {
|
||||||
|
SocketAddr::V4(v4) => {
|
||||||
|
req.push(1u8); // IPv4
|
||||||
|
req.extend_from_slice(&v4.ip().octets());
|
||||||
|
},
|
||||||
|
SocketAddr::V6(v6) => {
|
||||||
|
req.push(4u8); // IPv6
|
||||||
|
req.extend_from_slice(&v6.ip().octets());
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.extend_from_slice(&target.port().to_be_bytes());
|
||||||
|
|
||||||
|
stream.write_all(&req).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
// Response
|
||||||
|
let mut head = [0u8; 4];
|
||||||
|
stream.read_exact(&mut head).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
|
||||||
|
if head[1] != 0 {
|
||||||
|
return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1])));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip address part of response
|
||||||
|
match head[3] {
|
||||||
|
1 => { // IPv4
|
||||||
|
let mut addr = [0u8; 4 + 2];
|
||||||
|
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
},
|
||||||
|
3 => { // Domain
|
||||||
|
let mut len = [0u8; 1];
|
||||||
|
stream.read_exact(&mut len).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
let mut addr = vec![0u8; len[0] as usize + 2];
|
||||||
|
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
},
|
||||||
|
4 => { // IPv6
|
||||||
|
let mut addr = [0u8; 16 + 2];
|
||||||
|
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||||
|
},
|
||||||
|
_ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
488
src/transport/upstream.rs
Normal file
488
src/transport/upstream.rs
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
//! Upstream Management with per-DC latency-weighted selection
|
||||||
|
|
||||||
|
use std::net::{SocketAddr, IpAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use rand::Rng;
|
||||||
|
use tracing::{debug, warn, info, trace};
|
||||||
|
|
||||||
|
use crate::config::{UpstreamConfig, UpstreamType};
|
||||||
|
use crate::error::{Result, ProxyError};
|
||||||
|
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
|
||||||
|
use crate::transport::socket::create_outgoing_socket_bound;
|
||||||
|
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||||
|
|
||||||
|
/// Number of Telegram datacenters
|
||||||
|
const NUM_DCS: usize = 5;
|
||||||
|
|
||||||
|
// ============= RTT Tracking =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
struct LatencyEma {
|
||||||
|
value_ms: Option<f64>,
|
||||||
|
alpha: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LatencyEma {
|
||||||
|
const fn new(alpha: f64) -> Self {
|
||||||
|
Self { value_ms: None, alpha }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, sample_ms: f64) {
|
||||||
|
self.value_ms = Some(match self.value_ms {
|
||||||
|
None => sample_ms,
|
||||||
|
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self) -> Option<f64> {
|
||||||
|
self.value_ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Upstream State =============
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct UpstreamState {
|
||||||
|
config: UpstreamConfig,
|
||||||
|
healthy: bool,
|
||||||
|
fails: u32,
|
||||||
|
last_check: std::time::Instant,
|
||||||
|
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
|
||||||
|
dc_latency: [LatencyEma; NUM_DCS],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamState {
|
||||||
|
fn new(config: UpstreamConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
healthy: true,
|
||||||
|
fails: 0,
|
||||||
|
last_check: std::time::Instant::now(),
|
||||||
|
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map DC index to latency array slot (0..NUM_DCS).
|
||||||
|
///
|
||||||
|
/// Matches the C implementation's `mf_cluster_lookup` behavior:
|
||||||
|
/// - Standard DCs ±1..±5 → direct mapping to array index 0..4
|
||||||
|
/// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2)
|
||||||
|
/// This matches Telegram's `default 2;` in proxy-multi.conf.
|
||||||
|
/// - There is NO modular arithmetic in the C implementation.
|
||||||
|
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
|
||||||
|
let abs_dc = dc_idx.unsigned_abs() as usize;
|
||||||
|
if abs_dc == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if abs_dc >= 1 && abs_dc <= NUM_DCS {
|
||||||
|
Some(abs_dc - 1)
|
||||||
|
} else {
|
||||||
|
// Unknown DC → default cluster (DC 2, index 1)
|
||||||
|
// Same as C: mf_cluster_lookup returns default_cluster
|
||||||
|
Some(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get latency for a specific DC, falling back to average across all known DCs
|
||||||
|
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
|
||||||
|
// Try DC-specific latency first
|
||||||
|
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
|
||||||
|
if let Some(ms) = self.dc_latency[di].get() {
|
||||||
|
return Some(ms);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: average of all known DC latencies
|
||||||
|
let (sum, count) = self.dc_latency.iter()
|
||||||
|
.filter_map(|l| l.get())
|
||||||
|
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
|
||||||
|
|
||||||
|
if count > 0 { Some(sum / count as f64) } else { None }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of a single DC ping
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DcPingResult {
|
||||||
|
pub dc_idx: usize,
|
||||||
|
pub dc_addr: SocketAddr,
|
||||||
|
pub rtt_ms: Option<f64>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of startup ping for one upstream
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StartupPingResult {
|
||||||
|
pub results: Vec<DcPingResult>,
|
||||||
|
pub upstream_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Upstream Manager =============
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct UpstreamManager {
|
||||||
|
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamManager {
|
||||||
|
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
|
||||||
|
let states = configs.into_iter()
|
||||||
|
.filter(|c| c.enabled)
|
||||||
|
.map(UpstreamState::new)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
upstreams: Arc::new(RwLock::new(states)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Select upstream using latency-weighted random selection.
|
||||||
|
///
|
||||||
|
/// `effective_weight = config_weight × latency_factor`
|
||||||
|
///
|
||||||
|
/// where `latency_factor = 1000 / latency_ms` if latency is known,
|
||||||
|
/// or `1.0` if no latency data is available.
|
||||||
|
///
|
||||||
|
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
|
||||||
|
/// factor 5 — the faster route is 4× more likely to be chosen
|
||||||
|
/// (all else being equal).
|
||||||
|
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
|
||||||
|
let upstreams = self.upstreams.read().await;
|
||||||
|
if upstreams.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let healthy: Vec<usize> = upstreams.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, u)| u.healthy)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if healthy.is_empty() {
|
||||||
|
// All unhealthy — pick any
|
||||||
|
return Some(rand::rng().gen_range(0..upstreams.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if healthy.len() == 1 {
|
||||||
|
return Some(healthy[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate latency-weighted scores
|
||||||
|
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
|
||||||
|
let base = upstreams[i].config.weight as f64;
|
||||||
|
let latency_factor = upstreams[i].effective_latency(dc_idx)
|
||||||
|
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
|
||||||
|
.unwrap_or(1.0);
|
||||||
|
|
||||||
|
(i, base * latency_factor)
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
let total: f64 = weights.iter().map(|(_, w)| w).sum();
|
||||||
|
|
||||||
|
if total <= 0.0 {
|
||||||
|
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut choice: f64 = rand::rng().gen_range(0.0..total);
|
||||||
|
|
||||||
|
for &(idx, weight) in &weights {
|
||||||
|
if choice < weight {
|
||||||
|
trace!(
|
||||||
|
upstream = idx,
|
||||||
|
dc = ?dc_idx,
|
||||||
|
weight = format!("{:.2}", weight),
|
||||||
|
total = format!("{:.2}", total),
|
||||||
|
"Upstream selected"
|
||||||
|
);
|
||||||
|
return Some(idx);
|
||||||
|
}
|
||||||
|
choice -= weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(healthy[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to target through a selected upstream.
|
||||||
|
///
|
||||||
|
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
|
||||||
|
/// Pass `None` if DC index is unknown.
|
||||||
|
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
|
||||||
|
let idx = self.select_upstream(dc_idx).await
|
||||||
|
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
|
||||||
|
|
||||||
|
let upstream = {
|
||||||
|
let guard = self.upstreams.read().await;
|
||||||
|
guard[idx].config.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
match self.connect_via_upstream(&upstream, target).await {
|
||||||
|
Ok(stream) => {
|
||||||
|
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(idx) {
|
||||||
|
if !u.healthy {
|
||||||
|
debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
|
||||||
|
}
|
||||||
|
u.healthy = true;
|
||||||
|
u.fails = 0;
|
||||||
|
|
||||||
|
// Store per-DC latency
|
||||||
|
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
|
||||||
|
u.dc_latency[di].update(rtt_ms);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(stream)
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(idx) {
|
||||||
|
u.fails += 1;
|
||||||
|
warn!(fails = u.fails, "Upstream failed: {}", e);
|
||||||
|
if u.fails > 3 {
|
||||||
|
u.healthy = false;
|
||||||
|
warn!("Upstream marked unhealthy");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
|
||||||
|
match &config.upstream_type {
|
||||||
|
UpstreamType::Direct { interface } => {
|
||||||
|
let bind_ip = interface.as_ref()
|
||||||
|
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||||
|
|
||||||
|
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
||||||
|
|
||||||
|
socket.set_nonblocking(true)?;
|
||||||
|
match socket.connect(&target.into()) {
|
||||||
|
Ok(()) => {},
|
||||||
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
|
}
|
||||||
|
|
||||||
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
|
let stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
|
stream.writable().await?;
|
||||||
|
if let Some(e) = stream.take_error()? {
|
||||||
|
return Err(ProxyError::Io(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
},
|
||||||
|
UpstreamType::Socks4 { address, interface, user_id } => {
|
||||||
|
let proxy_addr: SocketAddr = address.parse()
|
||||||
|
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||||
|
|
||||||
|
let bind_ip = interface.as_ref()
|
||||||
|
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||||
|
|
||||||
|
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||||
|
|
||||||
|
socket.set_nonblocking(true)?;
|
||||||
|
match socket.connect(&proxy_addr.into()) {
|
||||||
|
Ok(()) => {},
|
||||||
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
|
}
|
||||||
|
|
||||||
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
|
let mut stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
|
stream.writable().await?;
|
||||||
|
if let Some(e) = stream.take_error()? {
|
||||||
|
return Err(ProxyError::Io(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
|
||||||
|
Ok(stream)
|
||||||
|
},
|
||||||
|
UpstreamType::Socks5 { address, interface, username, password } => {
|
||||||
|
let proxy_addr: SocketAddr = address.parse()
|
||||||
|
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||||
|
|
||||||
|
let bind_ip = interface.as_ref()
|
||||||
|
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||||
|
|
||||||
|
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||||
|
|
||||||
|
socket.set_nonblocking(true)?;
|
||||||
|
match socket.connect(&proxy_addr.into()) {
|
||||||
|
Ok(()) => {},
|
||||||
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||||
|
Err(err) => return Err(ProxyError::Io(err)),
|
||||||
|
}
|
||||||
|
|
||||||
|
let std_stream: std::net::TcpStream = socket.into();
|
||||||
|
let mut stream = TcpStream::from_std(std_stream)?;
|
||||||
|
|
||||||
|
stream.writable().await?;
|
||||||
|
if let Some(e) = stream.take_error()? {
|
||||||
|
return Err(ProxyError::Io(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
|
||||||
|
Ok(stream)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Startup Ping =============
|
||||||
|
|
||||||
|
/// Ping all Telegram DCs through all upstreams.
|
||||||
|
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
|
||||||
|
let upstreams: Vec<(usize, UpstreamConfig)> = {
|
||||||
|
let guard = self.upstreams.read().await;
|
||||||
|
guard.iter().enumerate()
|
||||||
|
.map(|(i, u)| (i, u.config.clone()))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
|
||||||
|
|
||||||
|
let mut all_results = Vec::new();
|
||||||
|
|
||||||
|
for (upstream_idx, upstream_config) in &upstreams {
|
||||||
|
let upstream_name = match &upstream_config.upstream_type {
|
||||||
|
UpstreamType::Direct { interface } => {
|
||||||
|
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
|
||||||
|
}
|
||||||
|
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
|
||||||
|
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut dc_results = Vec::new();
|
||||||
|
|
||||||
|
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
|
||||||
|
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
|
||||||
|
|
||||||
|
let ping_result = tokio::time::timeout(
|
||||||
|
Duration::from_secs(5),
|
||||||
|
self.ping_single_dc(upstream_config, dc_addr)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let result = match ping_result {
|
||||||
|
Ok(Ok(rtt_ms)) => {
|
||||||
|
// Store per-DC latency
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
if let Some(u) = guard.get_mut(*upstream_idx) {
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
}
|
||||||
|
DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr,
|
||||||
|
rtt_ms: Some(rtt_ms),
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
Err(_) => DcPingResult {
|
||||||
|
dc_idx: dc_zero_idx + 1,
|
||||||
|
dc_addr,
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some("timeout (5s)".to_string()),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
dc_results.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
all_results.push(StartupPingResult {
|
||||||
|
results: dc_results,
|
||||||
|
upstream_name,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
all_results
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let _stream = self.connect_via_upstream(config, target).await?;
|
||||||
|
Ok(start.elapsed().as_secs_f64() * 1000.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Health Checks =============
|
||||||
|
|
||||||
|
/// Background health check: rotates through DCs, 30s interval.
|
||||||
|
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
|
||||||
|
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
|
||||||
|
let mut dc_rotation = 0usize;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
|
||||||
|
let dc_zero_idx = dc_rotation % datacenters.len();
|
||||||
|
dc_rotation += 1;
|
||||||
|
|
||||||
|
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT);
|
||||||
|
|
||||||
|
let count = self.upstreams.read().await.len();
|
||||||
|
for i in 0..count {
|
||||||
|
let config = {
|
||||||
|
let guard = self.upstreams.read().await;
|
||||||
|
guard[i].config.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let result = tokio::time::timeout(
|
||||||
|
Duration::from_secs(10),
|
||||||
|
self.connect_via_upstream(&config, check_target)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let mut guard = self.upstreams.write().await;
|
||||||
|
let u = &mut guard[i];
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(_stream)) => {
|
||||||
|
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||||
|
u.dc_latency[dc_zero_idx].update(rtt_ms);
|
||||||
|
|
||||||
|
if !u.healthy {
|
||||||
|
info!(
|
||||||
|
rtt = format!("{:.0}ms", rtt_ms),
|
||||||
|
dc = dc_zero_idx + 1,
|
||||||
|
"Upstream recovered"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
u.healthy = true;
|
||||||
|
u.fails = 0;
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
u.fails += 1;
|
||||||
|
debug!(dc = dc_zero_idx + 1, fails = u.fails,
|
||||||
|
"Health check failed: {}", e);
|
||||||
|
if u.fails > 3 {
|
||||||
|
u.healthy = false;
|
||||||
|
warn!("Upstream unhealthy (fails)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
u.fails += 1;
|
||||||
|
debug!(dc = dc_zero_idx + 1, fails = u.fails,
|
||||||
|
"Health check timeout");
|
||||||
|
if u.fails > 3 {
|
||||||
|
u.healthy = false;
|
||||||
|
warn!("Upstream unhealthy (timeout)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.last_check = std::time::Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//! IP Addr Detect
|
//! IP Addr Detect
|
||||||
|
|
||||||
use std::net::IpAddr;
|
use std::net::{IpAddr, SocketAddr, UdpSocket};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
@@ -40,31 +40,77 @@ const IPV6_URLS: &[&str] = &[
|
|||||||
"http://api6.ipify.org/",
|
"http://api6.ipify.org/",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
/// Detect local IP address by connecting to a public DNS
|
||||||
|
/// This does not actually send any packets
|
||||||
|
fn get_local_ip(target: &str) -> Option<IpAddr> {
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
|
||||||
|
socket.connect(target).ok()?;
|
||||||
|
socket.local_addr().ok().map(|addr| addr.ip())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_local_ipv6(target: &str) -> Option<IpAddr> {
|
||||||
|
let socket = UdpSocket::bind("[::]:0").ok()?;
|
||||||
|
socket.connect(target).ok()?;
|
||||||
|
socket.local_addr().ok().map(|addr| addr.ip())
|
||||||
|
}
|
||||||
|
|
||||||
/// Detect public IP addresses
|
/// Detect public IP addresses
|
||||||
pub async fn detect_ip() -> IpInfo {
|
pub async fn detect_ip() -> IpInfo {
|
||||||
let mut info = IpInfo::default();
|
let mut info = IpInfo::default();
|
||||||
|
|
||||||
// Detect IPv4
|
// Try to get local interface IP first (default gateway interface)
|
||||||
|
// We connect to Google DNS to find out which interface is used for routing
|
||||||
|
if let Some(ip) = get_local_ip("8.8.8.8:80") {
|
||||||
|
if ip.is_ipv4() && !ip.is_loopback() {
|
||||||
|
info.ipv4 = Some(ip);
|
||||||
|
debug!(ip = %ip, "Detected local IPv4 address via routing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ip) = get_local_ipv6("[2001:4860:4860::8888]:80") {
|
||||||
|
if ip.is_ipv6() && !ip.is_loopback() {
|
||||||
|
info.ipv6 = Some(ip);
|
||||||
|
debug!(ip = %ip, "Detected local IPv6 address via routing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If local detection failed or returned private IP (and we want public),
|
||||||
|
// or just as a fallback/verification, we might want to check external services.
|
||||||
|
// However, the requirement is: "if IP for listening is not set... it should be IP from interface...
|
||||||
|
// if impossible - request external resources".
|
||||||
|
|
||||||
|
// So if we found a local IP, we might be good. But often servers are behind NAT.
|
||||||
|
// If the local IP is private, we probably want the public IP for the tg:// link.
|
||||||
|
// Let's check if the detected IPs are private.
|
||||||
|
|
||||||
|
let need_external_v4 = info.ipv4.map_or(true, |ip| is_private_ip(ip));
|
||||||
|
let need_external_v6 = info.ipv6.map_or(true, |ip| is_private_ip(ip));
|
||||||
|
|
||||||
|
if need_external_v4 {
|
||||||
|
debug!("Local IPv4 is private or missing, checking external services...");
|
||||||
for url in IPV4_URLS {
|
for url in IPV4_URLS {
|
||||||
if let Some(ip) = fetch_ip(url).await {
|
if let Some(ip) = fetch_ip(url).await {
|
||||||
if ip.is_ipv4() {
|
if ip.is_ipv4() {
|
||||||
info.ipv4 = Some(ip);
|
info.ipv4 = Some(ip);
|
||||||
debug!(ip = %ip, "Detected IPv4 address");
|
debug!(ip = %ip, "Detected public IPv4 address");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Detect IPv6
|
if need_external_v6 {
|
||||||
|
debug!("Local IPv6 is private or missing, checking external services...");
|
||||||
for url in IPV6_URLS {
|
for url in IPV6_URLS {
|
||||||
if let Some(ip) = fetch_ip(url).await {
|
if let Some(ip) = fetch_ip(url).await {
|
||||||
if ip.is_ipv6() {
|
if ip.is_ipv6() {
|
||||||
info.ipv6 = Some(ip);
|
info.ipv6 = Some(ip);
|
||||||
debug!(ip = %ip, "Detected IPv6 address");
|
debug!(ip = %ip, "Detected public IPv6 address");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !info.has_any() {
|
if !info.has_any() {
|
||||||
warn!("Failed to detect public IP address");
|
warn!("Failed to detect public IP address");
|
||||||
@@ -73,6 +119,17 @@ pub async fn detect_ip() -> IpInfo {
|
|||||||
info
|
info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return `true` if `ip` is not globally routable and therefore unsuitable
/// as a public address (loopback, RFC 1918 private, link-local, unspecified,
/// or IPv6 Unique Local).
///
/// Used by `detect_ip` to decide whether an externally-visible address must
/// still be fetched from a public "what is my IP" service.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => {
            ipv4.is_private()
                || ipv4.is_loopback()
                || ipv4.is_link_local()
                || ipv4.is_unspecified()
        }
        IpAddr::V6(ipv6) => {
            let first = ipv6.segments()[0];
            ipv6.is_loopback()
                || ipv6.is_unspecified()
                // Unique Local Addresses: fc00::/7 (RFC 4193)
                || (first & 0xfe00) == 0xfc00
                // Link-local: fe80::/10 (RFC 4291) — previously missed, which
                // let a non-routable fe80:: address pass as "public"
                || (first & 0xffc0) == 0xfe80
        }
    }
}
|
||||||
|
|
||||||
/// Fetch IP from URL
|
/// Fetch IP from URL
|
||||||
async fn fetch_ip(url: &str) -> Option<IpAddr> {
|
async fn fetch_ip(url: &str) -> Option<IpAddr> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
|
|||||||
12
telemt.service
Normal file
12
telemt.service
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=Telemt
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/bin
|
||||||
|
ExecStart=/bin/telemt /etc/telemt.toml
|
||||||
|
Restart=on-failure
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
Reference in New Issue
Block a user